OpenDAS / AutoAWQ · Commits

Commit d9bab50c, authored Aug 22, 2023 by Casper Hansen
Implement loading correct sequence length based on config + custom max_new_tokens
parent b53a9be2

Showing 5 changed files with 17 additions and 6 deletions (+17 −6):
awq/models/auto.py   +2 −2
awq/models/base.py   +12 −4
awq/models/llama.py  +1 −0
awq/models/mpt.py    +1 −0
awq/models/opt.py    +1 −0
awq/models/auto.py

```diff
@@ -32,10 +32,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename,
+    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
```
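With this change the entry point simply forwards `max_new_tokens` to the model-specific loader. A minimal usage sketch (the checkpoint paths below are placeholders for illustration, not part of this commit):

```python
from awq import AutoAWQForCausalLM

# Placeholder paths; substitute a real AWQ-quantized checkpoint.
quant_path = "path/to/awq-quantized-model"
quant_filename = "awq_model_w4_g128.pt"

# Omitting max_new_tokens (or passing None) lets the loader derive the
# sequence length from the model config; passing an int overrides it.
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename,
                                          max_new_tokens=512)
```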
awq/models/base.py

```diff
@@ -239,6 +239,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path,
             model_type,
             model_filename='',
+            max_new_tokens=None,
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
@@ -247,7 +248,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename,
+    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # [STEP 1] Download model if path is not a directory
@@ -263,7 +264,7 @@ class BaseAWQForCausalLM(nn.Module):
             # TODO: Better naming, model_filename becomes a directory
             model_filename = model_path + f'/{model_filename}'
 
-        # [STEP 2] Load config
+        # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
         quant_config_path = f'{model_path}/quant_config.json'
         if os.path.exists(quant_config_path):
@@ -273,7 +274,14 @@ class BaseAWQForCausalLM(nn.Module):
             # Default config that works for most models
             quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
 
-        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+        # Load model config and set max generation length
+        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
+        else:
+            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = max_new_tokens
 
         # [STEP 3] Load model
         with init_empty_weights():
```
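The new branch resolves the generation length in a fixed order: an explicit `max_new_tokens` argument wins, then the config attribute named by the subclass's `max_new_tokens_key`, then a 2048-token fallback. A standalone restatement of that resolution order (illustrative sketch only; the function name is not part of the diff):

```python
# Illustrative re-statement of the resolution logic in the hunk above.
def resolve_max_new_tokens(max_new_tokens, config, max_new_tokens_key=None):
    if max_new_tokens is None and max_new_tokens_key is not None:
        # No explicit override: read the length from the model config,
        # e.g. config.max_position_embeddings for LLaMA.
        return getattr(config, max_new_tokens_key)
    # Explicit override, or no known config key: fall back to 2048.
    return 2048 if max_new_tokens is None else max_new_tokens
```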
awq/models/llama.py

```diff
@@ -3,6 +3,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
```
awq/models/mpt.py

```diff
@@ -2,6 +2,7 @@ from .base import BaseAWQForCausalLM
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
+    max_new_tokens_key = "max_seq_len"
 
     @staticmethod
     def get_model_layers(model):
```
awq/models/opt.py

```diff
@@ -3,6 +3,7 @@ from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
 
 class OptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "OPTDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
     def get_model_layers(model: OPTForCausalLM):
```
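Each model class now declares which config attribute holds its maximum sequence length: `max_position_embeddings` for LLaMA and OPT, `max_seq_len` for MPT. Restated as a plain mapping for illustration (the library stores the key as a class attribute, not in a dict):

```python
from transformers import AutoConfig

# Per-architecture config keys added in this commit, restated as a dict.
MAX_NEW_TOKENS_KEYS = {
    "llama": "max_position_embeddings",
    "mpt": "max_seq_len",
    "opt": "max_position_embeddings",
}

def sequence_length_from_config(model_path: str, model_type: str) -> int:
    # Reads the context length the same way the base class does via getattr.
    config = AutoConfig.from_pretrained(model_path)
    return getattr(config, MAX_NEW_TOKENS_KEYS[model_type])
```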