Unverified Commit 82699474 authored by Jerry Zhang, committed by GitHub
Browse files

Small fixes for torchao quant (#2476)

parent 7154b4b1
......@@ -26,11 +26,12 @@ def apply_torchao_config_to_model(
quantize_,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.quant_api import _is_linear
if filter_fn is None:
def filter_fn(module, fqn):
return "proj" in fqn
return _is_linear(module) and "proj" in fqn
if torchao_config == "" or torchao_config is None:
return model
......
......@@ -157,6 +157,10 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
......@@ -165,10 +169,6 @@ class ModelRunner:
else:
self.torch_tp_applied = False
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Init memory pool and attention backends
if server_args.lora_paths is not None:
self.init_lora_manager()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment