Commit 3f9af065 authored by zhuwenwen
Browse files

Fix qwen3-235b run error and remove the HIP_VISIBLE_DEVICES update_environment_variables call

parent 3c318dbe
......@@ -52,9 +52,6 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
update_environment_variables({
"HIP_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
def _init_executor(self) -> None:
......
......@@ -89,7 +89,8 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM',
# TODO: 'Qwen2_5_VLForConditionalGeneration',
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM',
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
......
......@@ -414,6 +414,9 @@ class Qwen3MoeModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if self.use_llama_nn:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -490,7 +493,7 @@ class Qwen3MoeModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"gate_up_proj.weight",
"down_proj.weight",
......
......@@ -690,7 +690,7 @@ def update_environment_variables(envs: dict[str, str]):
logger.warning(
"Overwriting environment variable %s "
"from '%s' to '%s'", k, os.environ[k], v)
# os.environ[k] = v
os.environ[k] = v
def chunk_list(lst: list[T], chunk_size: int):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment