Unverified Commit fbe66e1d authored by orellavie1212's avatar orellavie1212 Committed by GitHub
Browse files

added support for quantize on LLM module (#1080)

parent 90979c38
...@@ -38,6 +38,9 @@ class LLM: ...@@ -38,6 +38,9 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
seed: The seed to initialize the random number generator for sampling.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
"""
...@@ -51,6 +54,7 @@ class LLM: ...@@ -51,6 +54,7 @@ class LLM:
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: str = "auto",
seed: int = 0, seed: int = 0,
quantization: Optional[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
...@@ -63,6 +67,7 @@ class LLM: ...@@ -63,6 +67,7 @@ class LLM:
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
seed=seed, seed=seed,
quantization=quantization,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args(engine_args) self.llm_engine = LLMEngine.from_engine_args(engine_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment