Unverified Commit 38216cf0 authored by Albert, committed by GitHub

concurrently load weights of DeepseekV2ForCausalLM (#7943)


Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
parent 4a883795
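
The change applies the standard fan-out/join pattern from concurrent.futures: each weight_loader call is submitted to a ThreadPoolExecutor instead of being run serially, and a final as_completed loop blocks until every task finishes. A minimal, self-contained sketch of that pattern follows; the load_one callback is a hypothetical stand-in for the model's weight_loader and is not part of the commit:

    import concurrent.futures

    def load_all(weights, load_one):
        # Fan out: submit one task per weight instead of loading serially.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(load_one, name, tensor) for name, tensor in weights
            ]
            # Join: as_completed yields futures as they finish, and
            # future.result() re-raises any exception from the worker thread.
            for future in concurrent.futures.as_completed(futures):
                future.result()

    # Toy usage with a no-op loader:
    load_all([("w1", 0), ("w2", 1)], lambda name, tensor: None)
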
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -2436,6 +2437,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
             params_dict = dict(self.named_parameters())
             weight_names = []
             for name, loaded_weight in weights:
@@ -2496,7 +2499,9 @@
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                futures.append(
+                    executor.submit(weight_loader, param, loaded_weight, shard_id)
+                )
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -2506,13 +2511,16 @@
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(
+                    futures.append(
+                        executor.submit(
+                            weight_loader,
                             param,
                             loaded_weight,
                             name,
                             shard_id=shard_id,
                             expert_id=expert_id,
                         )
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
@@ -2550,10 +2558,13 @@
                         [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                     )
                     param_name = (
-                        name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                        name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
                         if "q_a_proj" in name
                         else name.replace(
-                            "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                            "kv_a_proj_with_mqa",
+                            "fused_qkv_a_proj_with_mqa",
                         )
                     )
                     param = params_dict[param_name]
@@ -2561,7 +2572,9 @@
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
                         )
-                        weight_loader(param, fused_weight)
+                        futures.append(
+                            executor.submit(weight_loader, param, fused_weight)
+                        )
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
                     else:
@@ -2571,7 +2584,9 @@
                     # modelopt attn kv scale is named differently
                     for scale in ["k_scale", "v_scale"]:
                         if scale in name:
-                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                            name = name.replace(
+                                f"{scale[0]}_proj", "attn_mqa"
+                            )
                             break
                     if name not in params_dict:
                         # modelopt ckpt contains not needed weights for MTP module:
@@ -2583,7 +2598,13 @@
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
                     )
-                    weight_loader(param, loaded_weight)
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight)
+                    )
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
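A note on the join loop added at the end of load_weights: Future.result() re-raises any exception thrown inside a worker thread, so a failed weight_loader call still aborts loading instead of being silently dropped, and the with block guarantees the executor is shut down once all submitted tasks are done. A toy demonstration of that error propagation (not part of the commit):

    import concurrent.futures

    def boom(name):
        raise ValueError(f"failed to load {name}")

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(boom, "w0")]
        try:
            for future in concurrent.futures.as_completed(futures):
                future.result()  # re-raises the worker's ValueError here
        except ValueError as err:
            print(err)  # -> failed to load w0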