Unverified Commit 7d1485d3 authored by Chayenne, committed by GitHub

Add get weights by parameter name for llama (#2266)

parent 7d5d1d3d
...@@ -128,6 +128,7 @@ jobs:
          python3 test_mla_fp8.py
          python3 test_dp_attention.py

  performance-test-1-gpu-part-1:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    runs-on: 1-gpu-runner
...@@ -242,6 +243,7 @@ jobs:
          cd test/srt
          python3 test_eval_accuracy_large.py

  accuracy-test-2-gpu:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    runs-on: 2-gpu-runner
......
...@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
3. Modify the included server.sh by removing "loadTracer.sh" before the python command, and launch ./server.sh in one terminal inside the docker container.
4. Similar to step 6 in the RPD profiling section, but remove the last 2 lines in client.sh, which convert the rpd file into csv and json files. Run the modified client.sh for PyTorch profiling.
=======
-------
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
...@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


@dataclass
class AbortReq:
    # The request id
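For orientation, the two new dataclasses form a simple request/response pair carried between the tokenizer manager and the scheduler. A minimal sketch of how they relate (the parameter name and values below are illustrative, not taken from the diff):

# Hypothetical round trip with a Llama-style parameter name.
req = GetWeightsByNameReqInput(
    name="model.layers.0.self_attn.q_proj.weight", truncate_size=5
)
# The scheduler answers with the (truncated) values as a plain Python list.
out = GetWeightsByNameReqOutput(parameter=[0.01, -0.02, 0.03, 0.04, -0.05])
assert len(out.parameter) <= req.truncate_size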
......
...@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
...@@ -511,6 +513,9 @@ class Scheduler:
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightFromDiskReqOutput(success, message)
                )
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
...@@ -1373,6 +1378,10 @@ class Scheduler:
            logger.error(message)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
......
...@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
...@@ -454,6 +456,23 @@ class TokenizerManager:
        else:
            return False, "Another update is in progress. Please try again later."

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        self.send_to_scheduler.send_pyobj(obj)
        self.get_weights_by_name_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.get_weights_by_name_result
            return result.parameter
        else:
            self.get_weights_by_name_tmp = []
            result = await self.get_weights_by_name_result
            all_parameters = [r.parameter for r in result]
            return all_parameters

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
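Worth noting: with dp_size == 1 the caller gets a single parameter list back, while with data parallelism it gets one list per DP rank. A small illustrative sketch of the two return shapes (values are hypothetical):

# dp_size == 1: the future resolves to one GetWeightsByNameReqOutput
#   -> returns e.g. [0.01, -0.02, ...]
# dp_size == 2: the future resolves to a list of two outputs, one per DP rank
#   -> returns e.g. [[0.01, -0.02, ...], [0.01, -0.02, ...]]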
...@@ -527,6 +546,7 @@ class TokenizerManager:
                BatchEmbeddingOut,
                BatchTokenIDOut,
                UpdateWeightFromDiskReqOutput,
                GetWeightsByNameReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
...@@ -538,6 +558,16 @@ class TokenizerManager:
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                if self.server_args.dp_size == 1:
                    self.get_weights_by_name_result.set_result(recv_obj)
                else:
                    self.get_weights_by_name_tmp.append(recv_obj)
                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
                        self.get_weights_by_name_result.set_result(
                            self.get_weights_by_name_tmp
                        )
                continue
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id
......
...@@ -19,7 +19,10 @@ from typing import Optional
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    UpdateWeightFromDiskReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -160,3 +163,9 @@ class TpModelWorker:
            recv_req.model_path, recv_req.load_format
        )
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
        )
        return parameter
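Taken together, the layers simply delegate downward. A rough sketch of the call chain this diff wires up (method names are taken from the hunks above; the annotations are editorial):

# TokenizerManager.get_weights_by_name()      sends GetWeightsByNameReqInput over ZMQ
#   -> Scheduler.get_weights_by_name()        dispatches the request object
#     -> TpModelWorker.get_weights_by_name()  unpacks name / truncate_size
#       -> ModelRunner.get_weights_by_name()  calls into the model (e.g. LlamaForCausalLM)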
...@@ -23,7 +23,10 @@ from typing import Optional
import psutil
import torch
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    UpdateWeightFromDiskReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
...@@ -208,6 +211,9 @@ class TpModelWorkerClient:
        success, message = self.worker.update_weights_from_disk(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

    def __delete__(self):
        self.input_queue.put((None, None))
        self.copy_queue.put((None, None, None))
...@@ -20,13 +20,10 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
from typing import Optional, Type

import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
...@@ -403,6 +400,23 @@ class ModelRunner:
        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of a parameter by its name, similar to `get_parameter` in Hugging Face.
        Only used for unit tests; the performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
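As the docstring suggests, this path is intended for small correctness checks; for bulk export one would serialize tensors directly. A minimal sketch of that alternative, assuming `model` is the already-loaded nn.Module and the path is illustrative:

import torch

# Hypothetical offline round trip for a full tensor, outside the serving stack.
state_dict = model.state_dict()
torch.save(state_dict["model.embed_tokens.weight"], "/tmp/embed_tokens.pt")
restored = torch.load("/tmp/embed_tokens.pt")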
......
...@@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
...@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers

logger = logging.getLogger(__name__)


class LlamaMLP(nn.Module):
    def __init__(
...@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
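The mapping mirrors how load_weights fuses checkpoint shards: a Hugging Face name such as q_proj maps onto the fused qkv_proj parameter together with its shard id. A small illustrative lookup (the helper name is hypothetical, not part of the diff):

def map_hf_name(name, stacked_params_mapping):
    # e.g. "model.layers.0.self_attn.q_proj.weight"
    #   -> ("model.layers.0.self_attn.qkv_proj.weight", "q")
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None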
    @torch.no_grad()
    def forward(
...@@ -349,15 +360,7 @@ class LlamaForCausalLM(nn.Module):
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id, num_shard)
            ("qkv_proj", "q_proj", "q", 3),
            ("qkv_proj", "k_proj", "k", 3),
            ("qkv_proj", "v_proj", "v", 3),
            ("gate_up_proj", "gate_proj", 0, 2),
            ("gate_up_proj", "up_proj", 1, 2),
        ]
        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
...@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        embed_tokens_weight = None
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
...@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        load_tie_word_embeddings = (
...@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, embed_tokens_weight)
            if embed_tokens_weight is not None:
                weight_loader(param, embed_tokens_weight)

        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of a parameter by its name, similar to `get_parameter` in Hugging Face.
        Only used for unit tests; the performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            if mapped_name in params_dict:
                param = params_dict[mapped_name]
                if mapped_shard_id is not None:
                    if mapped_shard_id in ["q", "k", "v"]:
                        num_heads = self.config.num_attention_heads // tp_size
                        num_kv_heads = self.config.num_key_value_heads // tp_size
                        head_dim = (
                            self.config.hidden_size // self.config.num_attention_heads
                        )
                        if mapped_shard_id == "q":
                            offset = 0
                            size = num_heads * head_dim
                        elif mapped_shard_id == "k":
                            offset = num_heads * head_dim
                            size = num_kv_heads * head_dim
                        elif mapped_shard_id == "v":
                            offset = (num_heads + num_kv_heads) * head_dim
                            size = num_kv_heads * head_dim
                        weight = param.data.narrow(0, offset, size)
                    elif mapped_shard_id in [0, 1]:
                        intermediate_size = self.config.intermediate_size
                        hidden_size = self.config.hidden_size
                        slice_size = intermediate_size // tp_size
                        if mapped_shard_id == 0:  # gate_proj
                            offset = 0
                            size = slice_size
                        elif mapped_shard_id == 1:  # up_proj
                            offset = slice_size
                            size = slice_size
                        weight = param.data.narrow(0, offset, size)
                    else:
                        weight = param.data
                else:
                    weight = param.data
                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                    gathered_weights = [
                        torch.zeros_like(weight) for _ in range(tp_size)
                    ]
                    torch.distributed.all_gather(gathered_weights, weight)
                    weight = torch.cat(gathered_weights, dim=1)
                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
            else:
                return None
        except Exception as e:
            logger.error(
                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
            )
            return None
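To make the narrow() arithmetic concrete, here is a worked example under an assumed Llama-3-8B-like config (hidden_size 4096, 32 attention heads, 8 KV heads, tp_size 1); the numbers are illustrative, not taken from the diff:

# head_dim = 4096 // 32 = 128
# q slice:  offset 0,                size = 32 * 128 = 4096 rows of qkv_proj
# k slice:  offset 32 * 128 = 4096,  size = 8 * 128  = 1024 rows
# v slice:  offset (32 + 8) * 128 = 5120,  size = 8 * 128 = 1024 rows
# so qkv_proj.weight has 4096 + 1024 + 1024 = 6144 rows, split back into q/k/v here.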


class Phi3ForCausalLM(LlamaForCausalLM):
    pass
......
...@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    OpenSessionReqInput,
    UpdateWeightFromDiskReqInput,
)
...@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
    )


@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
    """Get model parameter by name."""
    try:
        ret = await tokenizer_manager.get_weights_by_name(obj, request)
        if ret is None:
            return ORJSONResponse(
                {"error": {"message": "Get parameter by name failed"}},
                status_code=HTTPStatus.BAD_REQUEST,
            )
        else:
            return ORJSONResponse(ret, status_code=200)
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
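For reference, the new route can be exercised against a running server the same way the test below does (URL and parameter name are illustrative):

import requests

# Mirrors the request made in test_get_parameter_by_name.py.
resp = requests.get(
    "http://127.0.0.1:30000/get_weights_by_name",
    json={"name": "model.embed_tokens.weight", "truncate_size": 10},
)
print(resp.json())  # first 10 values of the (gathered) parameter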
@app.api_route("/open_session", methods=["GET", "POST"]) @app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request): async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id.""" """Open a session, and return its unique session id."""
...@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
    )


@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
    """Handle a get parameter by name request."""
    try:
        ret = await tokenizer_manager.get_weights_by_name(obj, request)
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
...@@ -938,3 +969,8 @@ class Engine:

    async def get_server_info(self):
        return await _get_server_info()

    def get_weights_by_name(self, name, truncate_size=100):
        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(get_weights_by_name_request(obj, None))
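A short sketch of the offline path this enables (model path and parameter name are illustrative; Engine construction follows the test below):

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
values = engine.get_weights_by_name("model.layers.0.input_layernorm.weight", truncate_size=8)
print(values)  # first 8 entries of that parameter, as plain floats
engine.shutdown()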
...@@ -38,6 +38,7 @@ suites = {
        "test_update_weights.py",
        "test_vision_openai_server.py",
        "test_session_control.py",
        "test_get_parameter_by_name.py",
    ],
    "sampling/penaltylib": glob.glob(
        "sampling/penaltylib/**/test_*.py", recursive=True
......
import gc
import unittest

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
from sglang.utils import terminate_process


class TestUpdateWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.hf_model = AutoModelForCausalLM.from_pretrained(
            cls.model, torch_dtype="bfloat16"
        ).to("cuda:0")

    @classmethod
    def tearDownClass(cls):
        del cls.hf_model
        gc.collect()
        torch.cuda.empty_cache()

    def init_backend(self, backend, dp, tp):
        self.engine = None
        self.process = None
        self.backend = backend
        self.dp = dp
        self.tp = tp
        if backend == "Engine":
            self.engine = sgl.Engine(
                model_path=self.model,
                random_seed=42,
                tp_size=self.tp,
                dp_size=self.dp,
                mem_fraction_static=0.85,
            )
        else:
            self.process = popen_launch_server(
                self.model,
                self.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--tp-size",
                    str(tp),
                    "--dp-size",
                    str(dp),
                ),
            )

    def close_engine_and_server(self):
        if self.engine:
            self.engine.shutdown()
        if self.process:
            terminate_process(self.process)

    def assert_update_weights_all_close(self, param_name, truncate_size):
        print(
            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
        )
        param = self.hf_model.get_parameter(param_name)[:truncate_size]
        param_np = param.cpu().detach().float().numpy()

        if self.backend == "Engine":
            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
            engine_ret = self._process_return(engine_ret)
            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

        if self.backend == "Runtime":
            runtime_ret = requests.get(
                f"{self.base_url}/get_weights_by_name",
                json={"name": param_name, "truncate_size": truncate_size},
            ).json()
            runtime_ret = self._process_return(runtime_ret)
            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

    @staticmethod
    def _process_return(ret):
        if isinstance(ret, list) and len(ret) == 2:
            print(f"running assert_allclose on data parallel")
            np.testing.assert_allclose(ret[0], ret[1])
            return np.array(ret[0])
        return np.array(ret)

    def test_update_weights_unexist_model(self):
        test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
        if torch.cuda.device_count() >= 2:
            test_suits.append(("Engine", 1, 2))
            test_suits.append(("Runtime", 2, 1))
        if torch.cuda.device_count() >= 4:
            test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])

        parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

        for test_suit in test_suits:
            self.init_backend(*test_suit)
            for param_name in parameters:
                self.assert_update_weights_all_close(param_name, 100)
            self.close_engine_and_server()


if __name__ == "__main__":
    unittest.main()
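With the run_suite change above, CI picks this file up automatically; locally it can also be run on its own (GPU required) with `cd test/srt && python3 test_get_parameter_by_name.py`.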