AutoAWQ commit 7c976752 (unverified)

Pass arguments to AutoConfig (#97)

Authored Nov 11, 2023 by Andrey Glushenkov; committed by GitHub on Nov 11, 2023.
Co-authored-by: Casper <casperbh.96@gmail.com>
Parent: c5581b27
Changes: 2 changed files with 11 additions and 7 deletions (+11, -7)

awq/models/auto.py   +3 -2
awq/models/base.py   +8 -5
awq/models/auto.py

@@ -45,12 +45,13 @@ class AutoAWQForCausalLM:
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
                        batch_size=1, safetensors=True,
-                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
+                       max_memory=None, offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
             fuse_layers=fuse_layers, safetensors=safetensors,
-            max_memory=max_memory, offload_folder=offload_folder
+            max_memory=max_memory, offload_folder=offload_folder,
+            **config_kwargs
         )
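For context, a minimal usage sketch of the caller-facing effect (not part of the commit; the checkpoint name is hypothetical): extra keyword arguments passed to AutoAWQForCausalLM.from_quantized are now collected into **config_kwargs and forwarded, instead of raising a TypeError for an unexpected argument.

from awq import AutoAWQForCausalLM

# Any kwarg beyond the named parameters is captured by **config_kwargs
# and threaded through to the concrete model's from_quantized().
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-AWQ",  # hypothetical quantized checkpoint
    fuse_layers=True,
    revision="main",            # example kwarg forwarded via **config_kwargs
)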
awq/models/base.py

@@ -135,11 +135,13 @@ class BaseAWQForCausalLM(nn.Module):
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=True, is_quantized=True,
                        fuse_layers=False, version='GEMM',
-                       max_memory=None, offload_folder=None):
+                       max_memory=None, offload_folder=None,
+                       **config_kwargs):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
             self, model_path, model_filename, safetensors, version,
-            trust_remote_code, max_new_tokens=max_new_tokens
+            trust_remote_code, max_new_tokens=max_new_tokens,
+            **config_kwargs
         )
 
         # [STEP 3] Load model
@@ -184,7 +186,8 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
     def _load_config(self, model_path, model_filename, safetensors=True,
-                     version="GEMM", trust_remote_code=True, max_new_tokens=4096):
+                     version="GEMM", trust_remote_code=True, max_new_tokens=4096,
+                     **config_kwargs):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
@@ -206,11 +209,11 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
-            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
             config.max_new_tokens = getattr(config, self.max_new_tokens_key)
         else:
             max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
-            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
             config.max_new_tokens = max_new_tokens
 
         return model_weights_path, config, quant_config
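Net effect of the base.py changes: the forwarded kwargs land in AutoConfig.from_pretrained. A hedged sketch of what this enables (illustrative values, not from the commit); transformers honors hub kwargs such as revision here, and kwargs matching config attributes are treated as overrides of the loaded values:

from transformers import AutoConfig

# Roughly what _load_config now does with the forwarded kwargs.
config = AutoConfig.from_pretrained(
    "facebook/opt-125m",    # illustrative model id
    trust_remote_code=True,
    revision="main",        # example of a forwarded config kwarg
)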