Unverified commit 51cdd81f authored by Zilin Zhu, committed by GitHub

[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (#6265)

parent 73def253
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):
         # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if is_nextn:
+            layer_ids = [self.config.num_hidden_layers]
+        else:
+            if weight_names is None:
+                layer_ids = range(self.config.num_hidden_layers)
+            else:
+                layer_ids = set()
+                for name in weight_names:
+                    if "kv_b_proj" in name:
+                        layer_id = int(name.split(".")[2])
+                        # filter the nextn layer.
+                        if layer_id != self.config.num_hidden_layers:
+                            layer_ids.add(layer_id)
+
         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
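Note on the new branch above: when `weight_names` is supplied, the layer ids to post-process are recovered from the names of the weights that were just loaded, instead of looping over every hidden layer. A minimal, self-contained sketch of that parsing step; the concrete names and the layer count are hypothetical, assuming the usual `model.layers.<id>.self_attn.kv_b_proj.*` naming:

```python
# Standalone sketch of the layer-id extraction used in post_load_weights.
# Assumes HF-style names of the form "model.layers.<id>.self_attn.kv_b_proj.weight";
# the example names and num_hidden_layers value are made up for illustration.
num_hidden_layers = 61  # hypothetical; the nextn (MTP) layer sits at this index

weight_names = [
    "model.layers.3.self_attn.kv_b_proj.weight",
    "model.layers.61.self_attn.kv_b_proj.weight",     # nextn layer -> filtered out
    "model.layers.3.mlp.experts.0.gate_proj.weight",  # no kv_b_proj -> ignored
]

layer_ids = set()
for name in weight_names:
    if "kv_b_proj" in name:
        layer_id = int(name.split(".")[2])  # "model.layers.<id>. ..." -> <id>
        if layer_id != num_hidden_layers:   # skip the speculative nextn layer
            layer_ids.add(layer_id)

print(sorted(layer_ids))  # -> [3]
```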
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                )
+                self_attn.w_vc = bind_or_assign(
+                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+                )
                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    self_attn.w_scale = bind_or_assign(
+                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+                    )
                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                self_attn.w_scale_k = bind_or_assign(
+                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_scale_v = bind_or_assign(
+                    self_attn.w_scale_v, ws_vc.contiguous()
+                )
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
         # TODO support nextn later
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
 
+        weight_names = []
         for name, loaded_weight in weights:
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
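For orientation, the two hunks above wire the names collected in `load_weights` into `post_load_weights`, so a second or later weight push (as happens repeatedly during RL training) only re-derives `w_kc`/`w_vc` for the layers whose `kv_b_proj` was actually reloaded. A toy, self-contained stand-in illustrating just that handoff; the class and its bodies are placeholders, not the real `DeepseekV2ForCausalLM`:

```python
# Toy illustration of the load_weights -> post_load_weights handoff added here.
class ToyModel:
    num_hidden_layers = 4  # hypothetical

    def load_weights(self, weights, is_nextn=False):
        weight_names = []
        for name, loaded_weight in weights:
            weight_names.append(name)
            # ... the real loader copies loaded_weight into the matching param ...
        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

    def post_load_weights(self, is_nextn=False, weight_names=None):
        if weight_names is None:
            layer_ids = range(self.num_hidden_layers)  # first/full load: all layers
        else:
            layer_ids = {
                int(n.split(".")[2]) for n in weight_names if "kv_b_proj" in n
            }  # partial update: only the layers that were touched
        print("re-deriving w_kc/w_vc for layers:", sorted(layer_ids))


ToyModel().load_weights([("model.layers.1.self_attn.kv_b_proj.weight", None)])
# -> re-deriving w_kc/w_vc for layers: [1]
```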
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
     except Exception:
         # If anything fails, return empty string
         return ""
+
+
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
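The point of copying into an existing tensor rather than rebinding is that the first load allocates `w_kc`/`w_vc` once, and later updates overwrite that same storage in place, presumably so references held elsewhere keep seeing the refreshed values. A small demo of the helper's two paths on plain CPU tensors; the import assumes `bind_or_assign` is available from `sglang.srt.utils`, as added by the import hunk above:

```python
# Demo of bind_or_assign: bind on first use, copy in place on later updates.
import torch

from sglang.srt.utils import bind_or_assign

w_kc = None
w_kc = bind_or_assign(w_kc, torch.ones(2, 2))   # first load: source is bound as-is
original_ptr = w_kc.data_ptr()

w_kc = bind_or_assign(w_kc, torch.zeros(2, 2))  # later update: copied in place
assert w_kc.data_ptr() == original_ptr          # same storage as before
assert torch.equal(w_kc, torch.zeros(2, 2))     # values reflect the new weights
```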