Unverified commit c10b8e6a, authored by Nicolas Castet and committed by GitHub
Browse files

Support DP attention with GPT-OSS (#9359)

parent d4bce297
...@@ -1091,7 +1091,7 @@ class GptOssForCausalLM(nn.Module): ...@@ -1091,7 +1091,7 @@ class GptOssForCausalLM(nn.Module):
if name in params_dict.keys(): if name in params_dict.keys():
param = params_dict[name] param = params_dict[name]
if "sinks" in name: if "sinks" in name:
start = tp_rank * param.numel() start = get_attention_tp_rank() * param.numel()
param.data.copy_( param.data.copy_(
loaded_weight[start : start + param.numel()] loaded_weight[start : start + param.numel()]
) )
......
...@@ -2183,6 +2183,7 @@ class ServerArgs: ...@@ -2183,6 +2183,7 @@ class ServerArgs:
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'" ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
if is_sm100_supported(): if is_sm100_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True self.enable_flashinfer_allreduce_fusion = True
logger.info( logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM" "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment