Unverified Commit 27b557ae authored by Lianmin Zheng, committed by GitHub

Clean up model loader (#1440)

parent 93dffd69
@@ -415,7 +415,7 @@ class ModelTpServer:
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -936,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
......
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +179,7 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-        self.dtype = self.vllm_model_config.dtype
-
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
......
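Note: the ModelRunner changes above converge on a single error-handling pattern for weight updates: ModelRunner.update_weights builds a human-readable message and returns (success, message) without logging, and the caller (ModelTpServer, in the first hunks of this diff) logs the message once via logger.error. A minimal standalone sketch of that flow, with a hypothetical load_new_weights helper standing in for the real vllm loading code:

import logging

logger = logging.getLogger(__name__)


def load_new_weights(model_path: str, load_format: str) -> None:
    # Hypothetical stand-in for the real weight-loading logic; raises on bad input.
    if not model_path:
        raise ValueError("empty model path")


class ModelRunnerSketch:
    def update_weights(self, model_path: str, load_format: str):
        # Build a message and return (success, message); no logging here.
        try:
            load_new_weights(model_path, load_format)
        except Exception as e:
            return False, f"Failed to update weights: {e}."
        return True, "Succeeded to update model weights."


class ModelTpServerSketch:
    def __init__(self, model_runner: ModelRunnerSketch):
        self.model_runner = model_runner

    def update_weights(self, model_path: str, load_format: str):
        success, message = self.model_runner.update_weights(model_path, load_format)
        if not success:
            # Failures are logged exactly once, by the caller.
            logger.error(message)
        return success, message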
@@ -152,7 +152,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
......
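Note: with the endpoint change above, the update_weights response carries "success" as a JSON boolean instead of a string, so clients can branch on it directly. A minimal client-side sketch, assuming the requests package is available and the server listens at http://localhost:30000 with a POST /update_weights route (address and route are assumptions; adjust to your deployment):

import requests

BASE_URL = "http://localhost:30000"  # assumed server address

response = requests.post(
    f"{BASE_URL}/update_weights",
    json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
)
ret = response.json()

# "success" is now a real boolean, so no string comparison is needed.
if ret["success"]:
    print("Weights updated:", ret["message"])
else:
    print("Update failed:", ret["message"])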
@@ -187,7 +187,7 @@ def allocate_init_ports(
             cur_port += 1
 
     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
     try:
         resource.setrlimit(resource_type, (target_soft_limit, current_hard))
     except ValueError as e:
-        logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
+        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
 
 
 def add_api_key_middleware(app, api_key: str):
......
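Note: the logger.warn to logger.warning renames in this file (and in ModelTpServer above) follow the standard library: logging.Logger.warn is a deprecated alias for logging.Logger.warning. A tiny sketch of the preferred call:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Preferred: warning(); warn() is a deprecated alias kept only for backward compatibility.
logger.warning("Fail to set RLIMIT_NOFILE: %s", "example error")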
@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
         )
         print(json.dumps(response.json()))
         print("=" * 100)
-        # return the "text" in response
         text = response.json()["text"]
         return text
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
                 "model_path": model_path,
             },
         )
+        ret = response.json()
         print(json.dumps(response.json()))
+        return ret
 
     def test_replace_weights(self):
         origin_model_path = self.get_model_info()
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert ret["success"]
         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
         assert origin_response[:32] != updated_response[:32]
 
         # update weights back
-        self.run_update_weights(origin_model_path)
+        ret = self.run_update_weights(origin_model_path)
+        assert ret["success"]
         updated_model_path = self.get_model_info()
         assert updated_model_path == origin_model_path
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert not ret["success"]
         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
......