Commit 3f9af065 authored by zhuwenwen
Browse files

Fix qwen3-235b run error and remove the update_environment_variables call that set HIP_VISIBLE_DEVICES

parent 3c318dbe
...@@ -52,9 +52,6 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase): ...@@ -52,9 +52,6 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
update_environment_variables({ update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
}) })
update_environment_variables({
"HIP_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
def _init_executor(self) -> None: def _init_executor(self) -> None:
......
...@@ -89,7 +89,8 @@ def get_model_architecture( ...@@ -89,7 +89,8 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM', # TODO: 'Qwen2_5_VLForConditionalGeneration',
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM',
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM', 'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
......
...@@ -414,6 +414,9 @@ class Qwen3MoeModel(nn.Module): ...@@ -414,6 +414,9 @@ class Qwen3MoeModel(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.use_llama_nn:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
...@@ -490,7 +493,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -490,7 +493,7 @@ class Qwen3MoeModel(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None: if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [ lay_key_words = [
"gate_up_proj.weight", "gate_up_proj.weight",
"down_proj.weight", "down_proj.weight",
......
...@@ -690,7 +690,7 @@ def update_environment_variables(envs: dict[str, str]): ...@@ -690,7 +690,7 @@ def update_environment_variables(envs: dict[str, str]):
logger.warning( logger.warning(
"Overwriting environment variable %s " "Overwriting environment variable %s "
"from '%s' to '%s'", k, os.environ[k], v) "from '%s' to '%s'", k, os.environ[k], v)
# os.environ[k] = v os.environ[k] = v
def chunk_list(lst: list[T], chunk_size: int): def chunk_list(lst: list[T], chunk_size: int):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment