OpenDAS / AutoAWQ

Commit 4e7ada89
Authored Jul 03, 2023 by Abhinav Kulkarni

[Minor] Added max-memory command line parameter

Parent: d32095ab
Showing 1 changed file with 12 additions and 2 deletions.

awq/entry.py (+12, -2)
@@ -20,6 +20,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
 # model config
 parser.add_argument('--parallel', action='store_true',
                     help="enable model parallelism")
+# max memory to offload larger models to CPU
+parser.add_argument('--max_memory', type=str, nargs='*',
+                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+                    + "Example: 0:10GiB 1:10GiB cpu:20GiB; " \
+                    + "more details here: " \
+                    + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
 parser.add_argument('--auto_parallel', action='store_true',
                     help="automatically set parallel and batch_size")
 # quantization config
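With nargs='*', argparse collects the whitespace-separated device_id:max_memory pairs into a plain list of strings. A minimal standalone sketch of the flag's behavior (the sample values are illustrative, not from the commit):

import argparse

# Mirror of the flag defined above, fed a sample command line.
parser = argparse.ArgumentParser()
parser.add_argument('--max_memory', type=str, nargs='*')
args = parser.parse_args(['--max_memory', '0:10GiB', '1:10GiB', 'cpu:20GiB'])
print(args.max_memory)  # ['0:10GiB', '1:10GiB', 'cpu:20GiB']

If the flag is omitted entirely, args.max_memory is None, which the parsing code in the next hunk has to handle.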
@@ -43,6 +49,9 @@ parser.add_argument('--load_awq', type=str, default=None,
                     help="load the awq search results")
 args = parser.parse_args()
+
+max_memory = [v.split(':') for v in (args.max_memory or "")]
+max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
 if args.auto_parallel:
     gpu_list = auto_parallel(args)
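The two comprehensions convert that list into the dictionary shape that accelerate's big-model machinery expects: numeric device ids become ints, anything else (such as 'cpu') stays a string. A standalone trace with a sample value:

# Sample of what argparse hands back for --max_memory 0:10GiB 1:10GiB cpu:20GiB
pairs = ['0:10GiB', '1:10GiB', 'cpu:20GiB']
max_memory = [v.split(':') for v in (pairs or "")]                       # [['0', '10GiB'], ...]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}  # {0: '10GiB', ...}
print(max_memory)  # {0: '10GiB', 1: '10GiB', 'cpu': '20GiB'}

The `or ""` fallback makes the comprehension a no-op when the flag was not given (args.max_memory is None), so max_memory ends up as an empty dict instead of raising a TypeError.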
@@ -115,7 +124,7 @@ def build_model_and_enc(model_path):
     else:
         # Inference with fake quant
         # Init model on CPU:
-        kwargs = {"torch_dtype": torch.float16}
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, trust_remote_code=True, **kwargs)
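low_cpu_mem_usage=True is a standard transformers loading option: the model is first created as an empty shell and the checkpoint weights are then loaded into it directly, so peak host RAM stays near one copy of the weights instead of two while the fake-quant model is built on CPU.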
@@ -151,8 +160,9 @@ def build_model_and_enc(model_path):
         kwargs = {
             "torch_dtype": torch.float16,
             "device_map": "auto",
-            "max_memory": {0: "8GiB", "cpu": "99GiB"}
         }
+        if len(max_memory):
+            kwargs["max_memory"] = max_memory
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, state_dict=model.state_dict(),
             trust_remote_code=True, **kwargs)
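The hard-coded budget {0: "8GiB", "cpu": "99GiB"} is gone: the user-supplied mapping is forwarded only when --max_memory was actually given, and otherwise accelerate chooses the placement on its own. With device_map="auto", the mapping acts as a per-device cap during layer placement, and anything over the GPU budgets spills into CPU RAM. A hedged end-to-end sketch (the checkpoint name and budgets are illustrative, not from the commit):

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint and budgets: cap GPU 0 at 10GiB, offload the rest to CPU.
model = AutoModelForCausalLM.from_pretrained(
    'facebook/opt-1.3b',
    torch_dtype=torch.float16,
    device_map='auto',
    max_memory={0: '10GiB', 'cpu': '20GiB'},
)

The equivalent entry.py invocation would pass --max_memory 0:10GiB cpu:20GiB.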