Unverified commit 9f81d741, authored by wangyu, committed by GitHub

fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)


Signed-off-by: wangyu <wangyu.steph@bytedance.com>
parent a38c1497
@@ -14,8 +14,7 @@ python save_remote_state.py \
 Then, the model can be loaded with
 llm = Engine(
-    model_path="/path/to/save",
-    --remote-model-url [protocol]://[host]:[port]/[model_name],
+    model_path="[protocol]://[host]:[port]/[model_name]",
     tensor_parallel_size=8,
 )
 """
......
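For reference, a minimal sketch of what the revised docstring describes, assuming sglang exposes the offline Engine at the top level and that load_format="remote" routes to RemoteModelLoader; the Redis endpoint and model name are placeholders:

import sglang as sgl

# Hypothetical endpoint previously populated by save_remote_state.py.
llm = sgl.Engine(
    model_path="redis://localhost:6379/deepseek-v3",
    load_format="remote",      # assumption: selects RemoteModelLoader
    tensor_parallel_size=8,    # parameter name as in the docstring above
)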
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"

-def create_remote_connector(url, device="cpu") -> BaseConnector:
+def create_remote_connector(url, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
......
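The **kwargs signature lets call sites forward connector-specific options (such as the device= argument RemoteModelLoader passes further down) without every connector having to accept them. A usage sketch with a placeholder endpoint:

# RedisConnector takes only the URL now, so extra kwargs such as device
# are simply dropped on this path rather than causing a TypeError.
client = create_remote_connector(
    "redis://localhost:6379/deepseek-v3",  # hypothetical endpoint
    device="cuda",
)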
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
     <connector_type>://<host>:<port>/<model_name>/files/<filename>
     """

-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         self.url = url
-        self.device = device
         self.closed = False
         self.local_dir = tempfile.mkdtemp()
         for sig in (signal.SIGINT, signal.SIGTERM):
......
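Dropping device from the base class makes connectors pure byte/tensor fetchers; placement becomes the consumer's job. A self-contained toy showing that division of labor (the iterator name and loop shape are illustrative, not sglang's API):

import torch

def fake_weight_iterator():
    # Stand-in for a connector: always yields CPU tensors by name.
    yield "w_kc", torch.randn(8, 8)

# "cuda" in real use; the loader, not the connector, decides this.
params = {"w_kc": torch.empty(8, 8, device="cpu")}
for name, tensor in fake_weight_iterator():
    params[name].copy_(tensor.to(params[name].device))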
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)

 class RedisConnector(BaseKVConnector):

-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         import redis

-        super().__init__(url, device)
+        super().__init__(url)
         parsed_url = urlparse(url)
         self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
         self.model_name = parsed_url.path.lstrip("/")
......
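How the connector URL decomposes under urlparse, with a placeholder endpoint:

from urllib.parse import urlparse

u = urlparse("redis://localhost:6379/deepseek-v3")
print(u.hostname)          # "localhost"
print(u.port)              # 6379
print(u.path.lstrip("/"))  # "deepseek-v3", used as the model_name key prefix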
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
     if serde_type == "safe":
         s = SafeSerializer()
-        d = SafeDeserializer(torch.uint8)
+        d = SafeDeserializer()
     else:
         raise ValueError(f"Unknown serde type: {serde_type}")
......
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):

 class SafeDeserializer(Deserializer):
-    def __init__(self, dtype):
-        super().__init__(dtype)
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)

     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+        return load(bytes(b))["tensor_bytes"]

     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
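This is the core of the MLA fix on the serde side: the deserializer previously forced every tensor to a fixed dtype (torch.uint8, via the .to(dtype=self.dtype) call), clobbering bf16/fp8 MLA weights; now the dtype recorded in the safetensors payload wins. A minimal round trip, assuming the serializer stores each tensor under the "tensor_bytes" key:

import torch
from safetensors.torch import load, save

t = torch.randn(4, 4, dtype=torch.bfloat16)
payload = save({"tensor_bytes": t})   # dtype is recorded in the header
out = load(payload)["tensor_bytes"]   # no post-hoc cast anymore
assert out.dtype == torch.bfloat16    # dtype survives the round trip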
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)

-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)

         return model.eval()
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+        post_load_weights(model, model_config)
+
         return model.eval()

     @staticmethod
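With this call in place, sharded-state checkpoints of MLA models regain their derived attention weights at load time instead of silently skipping the post-processing stage. A usage sketch (assuming sglang accepts load_format="sharded_state" and the directory holds shards written by ShardedStateLoader; paths are placeholders):

import sglang as sgl

llm = sgl.Engine(
    model_path="/path/to/sharded_ckpt",  # hypothetical checkpoint dir
    load_format="sharded_state",         # assumption: routes to ShardedStateLoader
    tp_size=8,                           # must match the saved shard layout
)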
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                     f_key = f"{model_name}/files/{file_name}"
                     client.setstr(f_key, file_content)

-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)

-            with create_remote_connector(model_weights, device_config.device) as client:
+            with create_remote_connector(
+                model_weights, device=device_config.device
+            ) as client:
                 connector_type = get_connector_type(client)
                 if connector_type == ConnectorType.KV:
-                    self._load_model_from_remote_kv(model, client)
+                    self._load_model_from_remote_kv(model, model_config, client)
                 elif connector_type == ConnectorType.FS:
                     self._load_model_from_remote_fs(
                         model, client, model_config, device_config
......
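Two things changed here: quant-method post-processing moved out of the generic load path (it already runs inside _load_model_from_remote_kv, as the first hunk's context shows), and the metadata upload filter flipped from a deny-list to an allow-list. The old branch uploaded anything that was not a weight file and would therefore try to read binary sidecars such as tokenizer.model as UTF-8. A self-contained comparison of the two filters (file names are examples):

import os

OLD_SKIP = (".bin", ".pt", ".safetensors")
NEW_KEEP = (".json", ".py")

for name in ["config.json", "modeling_deepseek.py",
             "tokenizer.model", "model-00001.safetensors"]:
    ext = os.path.splitext(name)[1]
    uploaded_before = ext not in OLD_SKIP  # True for tokenizer.model: decode risk
    uploaded_now = ext in NEW_KEEP         # only the two text files pass
    print(f"{name}: before={uploaded_before} now={uploaded_now}")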
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],

 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
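Why centralizing this fixes MLA: DeepSeek-style models derive extra attention tensors inside model.post_load_weights(), so a loader that skips the hook produces a model whose MLA decode path has no weights. A toy illustration of the hook protocol (the fused-projection split below is a simplified stand-in for what sglang's DeepSeek implementation actually computes):

import torch
import torch.nn as nn

class ToyMLAModel(nn.Module):
    """Minimal model exposing the post_load_weights hook."""

    def __init__(self):
        super().__init__()
        self.kv_b_proj = nn.Linear(16, 32, bias=False)
        self.w_kc = None  # derived at load time, absent from the checkpoint
        self.w_vc = None

    def post_load_weights(self, is_nextn: bool = False):
        # Split the fused projection into the two decode-path factors.
        w = self.kv_b_proj.weight.detach()
        self.w_kc, self.w_vc = w.chunk(2, dim=0)

model = ToyMLAModel()
# ... any loader finishes copying weights here ...
if hasattr(model, "post_load_weights"):
    model.post_load_weights()  # the stage the fixed loaders now always run
assert model.w_kc is not None and model.w_vc is not None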