Commit 6e0bde15 authored by Casper Hansen's avatar Casper Hansen

Set AWQ_BATCH_SIZE environment variable

parent a2aa804c
@@ -1,3 +1,4 @@
+import os
 from transformers import AutoConfig
 from awq.models import *
 from awq.models.base import BaseAWQForCausalLM
@@ -35,7 +36,9 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
-                       device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
+                       device='balanced', trust_remote_code=True, fuse_layers=True,
+                       batch_size=1) -> BaseAWQForCausalLM:
+        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
...
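
The new `batch_size` argument is exported as the `AWQ_BATCH_SIZE` environment variable before the quantized weights are loaded. A minimal usage sketch, assuming the package exposes `AutoAWQForCausalLM` from `awq`; the model path and weights filename below are hypothetical:

    from awq import AutoAWQForCausalLM

    # batch_size is written to os.environ["AWQ_BATCH_SIZE"] before loading,
    # so code that reads the variable during model setup picks it up.
    model = AutoAWQForCausalLM.from_quantized(
        "path/to/quantized-model",    # hypothetical quant_path
        "awq_model_w4_g128.pt",       # hypothetical quant_filename
        batch_size=8,                 # results in AWQ_BATCH_SIZE=8
    )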