Unverified Commit b18416fb authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

Fix qwen3 tbo/dp-lm-head (#6652)

parent ce9d690e
...@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module): ...@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -370,7 +370,7 @@ def model_forward_maybe_tbo( ...@@ -370,7 +370,7 @@ def model_forward_maybe_tbo(
hidden_states=hidden_states, hidden_states=hidden_states,
forward_batch=forward_batch, forward_batch=forward_batch,
residual=residual, residual=residual,
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}), zero_allocator=zero_allocator,
) )
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
operations_strategy = OperationsStrategy.init_new_tbo( operations_strategy = OperationsStrategy.init_new_tbo(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment