Unverified Commit b18416fb authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

Fix qwen3 tbo/dp-lm-head (#6652)

parent ce9d690e
......@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -370,7 +370,7 @@ def model_forward_maybe_tbo(
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
zero_allocator=zero_allocator,
)
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
operations_strategy = OperationsStrategy.init_new_tbo(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment