"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5ca062e011bc7b9fd9b581a745dfeeed9e45737a"
Unverified Commit da4e8b38 authored by Hui Liu's avatar Hui Liu Committed by GitHub
Browse files

enable kv_scale remap (#3017)

parent af6c5357
...@@ -61,7 +61,10 @@ from sglang.srt.layers.radix_attention import RadixAttention ...@@ -61,7 +61,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import get_compiler_backend, set_weight_attrs from sglang.srt.utils import get_compiler_backend, set_weight_attrs
...@@ -372,6 +375,11 @@ class CohereForCausalLM(nn.Module): ...@@ -372,6 +375,11 @@ class CohereForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
...@@ -42,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -42,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module): ...@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
weight_loader(param, loaded_weight, weight_name) weight_loader(param, loaded_weight, weight_name)
break break
else: else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment