Unverified Commit 51cdd81f authored by Zilin Zhu, committed by GitHub

[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (#6265)

parent 73def253
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):
         # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if is_nextn:
+            layer_ids = [self.config.num_hidden_layers]
+        else:
+            if weight_names is None:
+                layer_ids = range(self.config.num_hidden_layers)
+            else:
+                layer_ids = set()
+                for name in weight_names:
+                    if "kv_b_proj" in name:
+                        layer_id = int(name.split(".")[2])
+                        # filter the nextn layer.
+                        if layer_id != self.config.num_hidden_layers:
+                            layer_ids.add(layer_id)
+
         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
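The layer-id filtering added above assumes parameter names of the usual "model.layers.<layer_id>.<...>" form, so position 2 of the dot-split name is the layer index. A minimal standalone sketch of that logic (made-up weight names and an example layer count, not taken from the commit) could look like this:

# Sketch only: reproduces the filtering in post_load_weights with hypothetical names.
num_hidden_layers = 61  # e.g. DeepSeek-V3; the nextn (MTP) layer would be index 61

weight_names = [
    "model.layers.3.self_attn.kv_b_proj.weight",    # kept -> layer 3
    "model.layers.3.mlp.gate_proj.weight",          # skipped: not a kv_b_proj weight
    "model.layers.61.self_attn.kv_b_proj.weight",   # skipped: nextn layer
]

layer_ids = set()
for name in weight_names:
    if "kv_b_proj" in name:
        layer_id = int(name.split(".")[2])  # "model.layers.<id>..." -> <id>
        if layer_id != num_hidden_layers:   # filter out the nextn layer
            layer_ids.add(layer_id)

print(layer_ids)  # {3}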
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                )
+                self_attn.w_vc = bind_or_assign(
+                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+                )
                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    self_attn.w_scale = bind_or_assign(
+                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+                    )
                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                self_attn.w_scale_k = bind_or_assign(
+                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_scale_v = bind_or_assign(
+                    self_attn.w_scale_v, ws_vc.contiguous()
+                )
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True

     # TODO support nextn later
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         ]

         params_dict = dict(self.named_parameters())
+        weight_names = []
         for name, loaded_weight in weights:
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         )
                         weight_loader(param, loaded_weight)

-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
     except Exception:
         # If anything fails, return empty string
         return ""
+
+
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
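A short usage sketch of the helper (illustrative only): the first call simply binds the freshly built tensor, while every later call copies into the existing storage, so the tensor object, and any reference another component may already hold, survives the update. This is presumably what makes post_load_weights safe to run more than once for repeated weight updates.

# Illustrative only; re-defines the helper so the snippet is self-contained.
import torch

def bind_or_assign(target, source):
    if target is not None:
        target.copy_(source)   # in-place copy keeps the existing tensor object
        return target
    return source              # first time: just hand back the new tensor

w_kc = bind_or_assign(None, torch.ones(2, 2))    # first load: plain assignment
held_elsewhere = w_kc                            # e.g. a reference kept by another component

w_kc = bind_or_assign(w_kc, torch.zeros(2, 2))   # later weight update: in-place copy
print(w_kc is held_elsewhere)                    # True  -> same object after the update
print(held_elsewhere.sum().item())               # 0.0   -> the holder sees the new values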