Commit 20eac49d authored by maxiao1

fix bug: relax the CUDA-only check in VisionFlash3Attention (add a HIP import path), switch RowParallelLinear to get_tp_group() and bypass symmetric-memory tagging, replace the symmetric-memory buffered add in DeepseekV2MoE with an in-place add, and comment out the should_use_flashinfer_trtllm_moe import and the batch.dimensions argument.

parent a1175a4e
@@ -297,8 +297,10 @@ class VisionFlash3Attention(nn.Module):
         self,
         **kwargs,
     ):
-        if not _is_cuda:
-            raise Exception("VisionFlash3Attention is only available for cuda")
+        # if not _is_cuda:
+        #     raise Exception("VisionFlash3Attention is only available for cuda")
+        if _is_hip:
+            from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func
         super().__init__()

     def forward(
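For context on the guard being relaxed here: the commented-out check rejected non-CUDA builds outright, while the new branch imports the varlen FlashAttention wrapper on HIP. Below is a minimal sketch of how such platform flags can be derived; treat _is_cuda / _is_hip and the guarded import as illustrative assumptions, since the diff does not show how sglang actually defines them.

# Illustrative sketch only: platform flags assumed to come from the torch build.
import torch

_is_cuda = torch.version.cuda is not None  # PyTorch built with CUDA
_is_hip = torch.version.hip is not None    # PyTorch built with ROCm/HIP

if _is_hip:
    # Mirror the new branch above: pull in the varlen FlashAttention wrapper
    # instead of raising "only available for cuda". The import is guarded
    # because sglang may not be installed in a standalone environment.
    try:
        from sglang.srt.layers.attention.flashattention_interface import (
            flash_attn_varlen_func,
        )
    except ImportError:
        flash_attn_varlen_func = None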
@@ -1389,16 +1389,16 @@ class RowParallelLinear(LinearBase):
         if use_fused_silu_mul_quant:
             xq, xs = lm_fuse_silu_mul_quant(input_parallel)
             silu_quant_args = [xq, xs]
-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            with use_symmetric_memory(get_tp_group()) as sm:
                 output_parallel = self.quant_method.apply(self, input_parallel,
                                                           bias=bias_,
                                                           silu_quant_args=silu_quant_args
                                                           )
-                sm.tag(output_parallel)
+                # sm.tag(output_parallel)
         else:
-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            with use_symmetric_memory(get_tp_group()) as sm:
                 output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-                sm.tag(output_parallel)
+                # sm.tag(output_parallel)

         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
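The hunk above comments out sm.tag(...) at each call site and switches to a bare get_tp_group(). An alternative that keeps the call sites unchanged is a no-op stand-in for the symmetric-memory context; the sketch below is a hypothetical illustration of that pattern, not sglang's actual use_symmetric_memory implementation.

# Hypothetical no-op stand-in (not sglang's API): keeps `with ... as sm:` and
# `sm.tag(tensor)` valid on platforms where symmetric memory is unavailable.
from contextlib import contextmanager


class _NoopSymmetricMemory:
    def tag(self, tensor):
        # Nothing to tag when symmetric memory is disabled.
        return tensor


@contextmanager
def use_symmetric_memory_noop(tp_group=None):
    # Ignores the TP group entirely; a real implementation would register the
    # group's symmetric-memory pool here.
    yield _NoopSymmetricMemory()

With a fallback like this, the two sm.tag(output_parallel) lines above could stay in place rather than being commented out.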
@@ -13,7 +13,7 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     get_moe_runner_backend,
-    should_use_flashinfer_trtllm_moe,
+    # should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
@@ -362,7 +362,7 @@ class ForwardBatch:
             input_embeds=batch.input_embeds,
             token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
-            dimensions=batch.dimensions,
+            # dimensions=batch.dimensions,
         )

         device = model_runner.device
@@ -868,11 +868,12 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
-                final_hidden_states_out = torch.empty_like(final_hidden_states)
-                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-                final_hidden_states = final_hidden_states_out
-                sm.tag(final_hidden_states)
+            # with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            #     final_hidden_states_out = torch.empty_like(final_hidden_states)
+            #     torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            #     final_hidden_states = final_hidden_states_out
+            #     sm.tag(final_hidden_states)
+            final_hidden_states += shared_output
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
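Functionally, the hunk above swaps the symmetric-memory buffered add (torch.empty_like plus torch.add(..., out=...)) for a plain in-place +=. A small standalone check of the numerical equivalence, outside sglang and with made-up shapes:

import torch

final_hidden_states = torch.randn(4, 8)
shared_output = torch.randn(4, 8)

# Removed path: add into a freshly allocated buffer.
buffered = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=buffered)

# New path: add in place, reusing the existing storage.
in_place = final_hidden_states.clone()
in_place += shared_output

assert torch.allclose(buffered, in_place)

The in-place version mutates final_hidden_states directly instead of producing a new tensor, which is fine as long as no other live reference aliases that buffer; that aliasing question is the main thing to verify when making this kind of swap.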