Unverified commit 9f81d741, authored by wangyu and committed by GitHub
Browse files

fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)


Signed-off-by: wangyu <wangyu.steph@bytedance.com>
parent a38c1497
......@@ -14,8 +14,7 @@ python save_remote_state.py \
Then, the model can be loaded with
llm = Engine(
model_path="/path/to/save",
--remote-model-url [protocol]://[host]:[port]/[model_name],
model_path="[protocol]://[host]:[port]/[model_name]",
tensor_parallel_size=8,
)
"""
......
......@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
KV = "KV"
def create_remote_connector(url, device="cpu") -> BaseConnector:
def create_remote_connector(url, **kwargs) -> BaseConnector:
connector_type = parse_connector_type(url)
if connector_type == "redis":
return RedisConnector(url)
......
......@@ -20,9 +20,8 @@ class BaseConnector(ABC):
<connector_type://<host>:<port>/<model_name>/files/<filename>
"""
def __init__(self, url: str, device: torch.device = "cpu"):
def __init__(self, url: str):
self.url = url
self.device = device
self.closed = False
self.local_dir = tempfile.mkdtemp()
for sig in (signal.SIGINT, signal.SIGTERM):
......
......@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
class RedisConnector(BaseKVConnector):
def __init__(self, url: str, device: torch.device = "cpu"):
def __init__(self, url: str):
import redis
super().__init__(url, device)
super().__init__(url)
parsed_url = urlparse(url)
self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
self.model_name = parsed_url.path.lstrip("/")
......
......@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
if serde_type == "safe":
s = SafeSerializer()
d = SafeDeserializer(torch.uint8)
d = SafeDeserializer()
else:
raise ValueError(f"Unknown serde type: {serde_type}")
......
......@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
class SafeDeserializer(Deserializer):
def __init__(self, dtype):
super().__init__(dtype)
def __init__(self):
# TODO: dtype options
super().__init__(torch.float32)
def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
return load(bytes(b))["tensor_bytes"]
def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
return self.from_bytes_normal(b)
......@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
set_default_torch_dtype,
)
from sglang.srt.model_loader.weight_utils import (
......@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
# random values to the weights.
initialize_dummy_weights(model)
# Model weight loading consists of two stages:
# 1. Initial weight loading.
# 2. Post-processing of weights, including assigning specific member variables.
# For `dummy_init`, only the second stage is required.
if hasattr(model, "post_load_weights"):
if (
model_config.hf_config.architectures[0]
== "DeepseekV3ForCausalLMNextN"
):
model.post_load_weights(is_nextn=True)
else:
model.post_load_weights()
post_load_weights(model, model_config)
return model.eval()
......@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
state_dict.pop(key)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
post_load_weights(model, model_config)
return model.eval()
@staticmethod
......@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
# ignore hidden files
if file_name.startswith("."):
continue
if os.path.splitext(file_name)[1] not in (
".bin",
".pt",
".safetensors",
):
if os.path.splitext(file_name)[1] in (".json", ".py"):
file_path = os.path.join(root, file_name)
with open(file_path, encoding="utf-8") as file:
file_content = file.read()
f_key = f"{model_name}/files/{file_name}"
client.setstr(f_key, file_content)
def _load_model_from_remote_kv(self, model: nn.Module, client):
def _load_model_from_remote_kv(
self, model: nn.Module, model_config: ModelConfig, client
):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
......@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
post_load_weights(model, model_config)
def _load_model_from_remote_fs(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
......@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
with create_remote_connector(model_weights, device_config.device) as client:
with create_remote_connector(
model_weights, device=device_config.device
) as client:
connector_type = get_connector_type(client)
if connector_type == ConnectorType.KV:
self._load_model_from_remote_kv(model, client)
self._load_model_from_remote_kv(model, model_config, client)
elif connector_type == ConnectorType.FS:
self._load_model_from_remote_fs(
model, client, model_config, device_config
......
......@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
def post_load_weights(model: nn.Module, model_config: ModelConfig):
    """Invoke the model's optional ``post_load_weights`` hook.

    Weight loading happens in two stages:
      1. the initial loading of the raw weights, and
      2. post-processing of those weights (e.g. assigning specific
         member variables).
    This helper performs stage 2 only, so the different model loaders
    (dummy, sharded, remote) can share one code path.

    Models that define no ``post_load_weights`` method are left untouched.
    """
    hook = getattr(model, "post_load_weights", None)
    if hook is None:
        return
    # The DeepSeek-V3 NextN architecture is the one case that must be told
    # it is the NextN variant; every other hook is called with no arguments.
    arch = model_config.hf_config.architectures[0]
    if arch == "DeepseekV3ForCausalLMNextN":
        hook(is_nextn=True)
    else:
        hook()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment