Commit 4f419c00 (unverified), authored by Flex Wang and committed by GitHub
Browse files

Fix ShardedStateLoader for vllm fp8 quantization (#7708)

parent a3fce56b
...@@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(model_config, self.load_config,
lora_config, cache_config) lora_config, cache_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
pattern = os.path.join( pattern = os.path.join(
local_model_path, local_model_path,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment