Unverified Commit 38216cf0 authored by Albert, committed by GitHub

concurrently load weights of DeepseekV2ForCausalLM (#7943)


Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
parent 4a883795
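
The diff below wraps the existing weight-loading loop in a concurrent.futures.ThreadPoolExecutor: each weight_loader call becomes a submitted task, and the loader then drains the futures so exceptions raised in worker threads still propagate. A minimal sketch of that pattern, using a hypothetical load_one stand-in for the real weight loaders:

    import concurrent.futures

    def load_one(name: str, numel: int) -> str:
        # Stand-in for a real weight_loader(param, loaded_weight) call.
        return f"{name}: loaded {numel} elements"

    weights = [("model.layers.0.mlp.gate_proj", 1024), ("model.embed_tokens", 4096)]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(load_one, name, n) for name, n in weights]
        # as_completed yields each future as it finishes; result() re-raises
        # any exception thrown inside the worker, so failures are not lost.
        for future in concurrent.futures.as_completed(futures):
            print(future.result())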
@@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
import concurrent.futures
import logging
import os
from enum import IntEnum, auto

@@ -2436,154 +2437,174 @@ class DeepseekV2ForCausalLM(nn.Module):
assert self.num_fused_shared_experts == 1
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = []
    params_dict = dict(self.named_parameters())
    weight_names = []
    for name, loaded_weight in weights:
        if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
            name = name.replace(
                "mlp.shared_experts",
                f"mlp.experts.{self.config.n_routed_experts}",
            )
        weight_names.append(name)
        if not is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                if num_nextn_layers > 0 and name.startswith("model.layers"):
                    name_list = name.split(".")
                    if (
                        len(name_list) >= 3
                        and int(name_list[2]) >= self.config.num_hidden_layers
                    ):
                        continue
        else:
            if not name.startswith(nextn_layer_prefix):
                continue
            # Use shared head and embed weights from target model
            if "shared_head.head" in name or "embed_tokens" in name:
                continue
            is_decoder = True
            # For nextn specific weights
            for weight_name in nextn_spec_weight_names:
                if weight_name in name:
                    name = name.replace(nextn_layer_prefix, "model")
                    is_decoder = False
                    break
            # For decoder layer weights
            if is_decoder:
                name = name.replace(nextn_layer_prefix, "model.decoder")
        if "rotary_emb.inv_freq" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            futures.append(
                executor.submit(weight_loader, param, loaded_weight, shard_id)
            )
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                futures.append(
                    executor.submit(
                        weight_loader,
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if fuse_qkv_a_proj and (
                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                ):
                    cached_a_proj[name] = loaded_weight
                    q_a_proj_name = (
                        name
                        if "q_a_proj" in name
                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                    )
                    kv_a_proj_name = (
                        name
                        if "kv_a_proj_with_mqa" in name
                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                    )

                    # When both q_a_proj and kv_a_proj_with_mqa have been
                    # cached, load the fused weight into the parameter.
                    if (
                        q_a_proj_name in cached_a_proj
                        and kv_a_proj_name in cached_a_proj
                    ):
                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                        cat_dim = 0
                        if self.quant_config is not None and (
                            self.quant_config.get_name() == "awq"
                            or self.quant_config.get_name() == "moe_wna16"
                        ):
                            cat_dim = 1
                        fused_weight = torch.cat(
                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                        )
                        param_name = (
                            name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            if "q_a_proj" in name
                            else name.replace(
                                "kv_a_proj_with_mqa",
                                "fused_qkv_a_proj_with_mqa",
                            )
                        )
                        param = params_dict[param_name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        futures.append(
                            executor.submit(weight_loader, param, fused_weight)
                        )
                        cached_a_proj.pop(q_a_proj_name)
                        cached_a_proj.pop(kv_a_proj_name)
                else:
                    if (
                        "k_scale" in name or "v_scale" in name
                    ) and name not in params_dict:
                        # modelopt attn kv scale is named differently
                        for scale in ["k_scale", "v_scale"]:
                            if scale in name:
                                name = name.replace(
                                    f"{scale[0]}_proj", "attn_mqa"
                                )
                                break
                    if name not in params_dict:
                        # modelopt ckpt contains unneeded weights for the
                        # MTP module:
                        # model.decoder.self_attn.attn_mqa.v_scale and
                        # model.decoder.self_attn.attn_mqa.k_scale
                        logger.warning(f"{name} not found in params_dict.")
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight)
                    )

    # Wait for all tasks to complete and raise any exceptions.
    for future in concurrent.futures.as_completed(futures):
        future.result()

self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
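
One subtle step above is the fused q_a_proj/kv_a_proj_with_mqa path: the two checkpoint tensors can arrive in either order, so each half is cached until both are present and then concatenated into a single fused_qkv_a_proj_with_mqa weight (along dim 1 for awq/moe_wna16 quantization, dim 0 otherwise). A self-contained sketch of that fusion, with a hypothetical maybe_fuse helper and made-up shapes:

    import torch

    cached_a_proj = {}

    def maybe_fuse(name: str, weight: torch.Tensor, cat_dim: int = 0):
        # Cache q_a_proj / kv_a_proj_with_mqa halves until both are present,
        # then return the fused parameter name and tensor (else None).
        cached_a_proj[name] = weight
        q_name = (
            name
            if "q_a_proj" in name
            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
        )
        kv_name = (
            name
            if "kv_a_proj_with_mqa" in name
            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
        )
        if q_name in cached_a_proj and kv_name in cached_a_proj:
            fused = torch.cat(
                [cached_a_proj.pop(q_name), cached_a_proj.pop(kv_name)],
                dim=cat_dim,
            )
            return q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa"), fused
        return None

    # Checkpoint order is not guaranteed, so either half may arrive first.
    maybe_fuse("layers.0.self_attn.kv_a_proj_with_mqa.weight", torch.zeros(576, 7168))
    out = maybe_fuse("layers.0.self_attn.q_a_proj.weight", torch.zeros(1536, 7168))
    print(out[0], out[1].shape)  # ...fused_qkv_a_proj_with_mqa.weight torch.Size([2112, 7168])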
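The loader's dispatch also leans on Python's for/else: an else arm on a for loop runs only when the loop finishes without break, which is how a weight name that matches no stacked-parameter entry falls through to the expert mapping, and from there to the default path. A tiny illustration with hypothetical mappings:

    # Hypothetical mappings, shaped like stacked_params_mapping /
    # expert_params_mapping in the diff above.
    stacked = [("gate_up_proj", "gate_proj"), ("gate_up_proj", "up_proj")]
    experts = [("experts.w13_weight", "experts.0.gate_proj")]
    params_dict = {
        "layers.0.mlp.gate_up_proj",
        "layers.1.mlp.experts.w13_weight",
        "model.embed_tokens",
    }

    def dispatch(name: str) -> str:
        result = None
        for param_name, weight_name in stacked:
            if weight_name not in name:
                continue
            # Expert weights also match stacked names; skip them here so the
            # expert mapping below picks them up (see the comment in the diff).
            if "mlp.experts." in name and name not in params_dict:
                continue
            result = f"stacked -> {name.replace(weight_name, param_name)}"
            break
        else:
            # Runs only when the loop above finished without break, i.e. the
            # name matched no stacked-parameter entry.
            for param_name, weight_name in experts:
                if weight_name not in name:
                    continue
                result = f"expert -> {name.replace(weight_name, param_name)}"
                break
            else:
                result = f"default -> {name}"
        return result

    print(dispatch("layers.0.mlp.up_proj"))              # stacked -> layers.0.mlp.gate_up_proj
    print(dispatch("layers.1.mlp.experts.0.gate_proj"))  # expert -> layers.1.mlp.experts.w13_weight
    print(dispatch("model.embed_tokens"))                # default -> model.embed_tokens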