Unverified Commit 2a5f0100 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Fix GGuf and add back test_gguf.py (#7067)

parent dbdf76ca
...@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
...@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
......
...@@ -1259,12 +1259,19 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -1259,12 +1259,19 @@ class GGUFModelLoader(BaseModelLoader):
): ):
model_config.hf_config.update({"tie_word_embeddings": True}) model_config.hf_config.update({"tie_word_embeddings": True})
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with target_device:
model = _initialize_model(model_config, self.load_config) model = _initialize_model(model_config, self.load_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map) self._get_weights_iterator(local_model_path, gguf_weights_map)
) )
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model return model
......
...@@ -186,7 +186,7 @@ suites = { ...@@ -186,7 +186,7 @@ suites = {
"vllm_dependency_test": [ "vllm_dependency_test": [
TestFile("test_awq.py"), TestFile("test_awq.py"),
TestFile("test_bnb.py"), TestFile("test_bnb.py"),
# TestFile("test_gguf.py", 78), # TODO: Fix GGuf after updating to torch 2.7 and vllm 0.9 TestFile("test_gguf.py", 78),
TestFile("test_gptqmodel_dynamic.py", 72), TestFile("test_gptqmodel_dynamic.py", 72),
TestFile("test_vllm_dependency.py"), TestFile("test_vllm_dependency.py"),
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment