chenpangpang / transformers · Commit 66d1eee6

Unverified commit 66d1eee6, authored Mar 27, 2023 by кѳѳsнī, committed by GitHub on Mar 27, 2023
load_in_8bit now respects 'balanced' device maps in multi-gpu environments (#22377)
balanced 8bit memory
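For context, the user-facing effect of the change is that an 8-bit load can now be spread evenly across several GPUs instead of filling them one at a time. A minimal usage sketch, assuming a multi-GPU machine with bitsandbytes installed; the checkpoint name is a placeholder, not something this commit prescribes:

from transformers import AutoModelForCausalLM

# device_map="balanced" asks accelerate to split layers evenly across all visible GPUs;
# load_in_8bit=True quantizes the weights with bitsandbytes while loading.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-causal-lm",  # placeholder checkpoint
    device_map="balanced",
    load_in_8bit=True,
)

# The per-module placement chosen during loading is recorded on the model.
print(model.hf_device_map)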
parent 8cfc6678
Showing 2 changed files with 20 additions and 7 deletions:

    src/transformers/modeling_utils.py               +19 -6
    src/transformers/models/llama/modeling_llama.py   +1 -1
src/transformers/modeling_utils.py (view file @ 66d1eee6)
@@ -2542,11 +2542,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             ) >= version.parse("0.37.0")

         if isinstance(device_map, str):
-            special_dtypes = {
-                name: torch.float32
-                for name, _ in model.named_parameters()
-                if any(m in name for m in keep_in_fp32_modules)
-            }
+            special_dtypes = {}
+            if load_in_8bit:
+                special_dtypes.update(
+                    {
+                        name: torch_dtype
+                        for name, _ in model.named_parameters()
+                        if any(m in name for m in modules_to_not_convert)
+                    }
+                )
+
+            special_dtypes.update(
+                {
+                    name: torch.float32
+                    for name, _ in model.named_parameters()
+                    if any(m in name for m in keep_in_fp32_modules)
+                }
+            )
+
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
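To make the new branch concrete, here is a small self-contained sketch of the dictionary it builds, using a toy nn.Sequential in place of a real PreTrainedModel; the module lists are hypothetical stand-ins for the values transformers derives internally (modules_to_not_convert, keep_in_fp32_modules):

import torch
import torch.nn as nn

# Toy stand-in for a pretrained model.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))

torch_dtype = torch.float16
load_in_8bit = True
modules_to_not_convert = ["2"]  # hypothetical: modules excluded from 8-bit conversion (e.g. an lm_head)
keep_in_fp32_modules = ["1"]    # hypothetical: modules pinned to float32 (e.g. layer norms)

special_dtypes = {}
if load_in_8bit:
    # Parameters that bitsandbytes will skip keep the requested torch_dtype rather than int8.
    special_dtypes.update(
        {
            name: torch_dtype
            for name, _ in model.named_parameters()
            if any(m in name for m in modules_to_not_convert)
        }
    )

# Parameters that must stay in float32 are recorded second, so they win on any overlap.
special_dtypes.update(
    {
        name: torch.float32
        for name, _ in model.named_parameters()
        if any(m in name for m in keep_in_fp32_modules)
    }
)

print(special_dtypes)
# {'2.weight': torch.float16, '2.bias': torch.float16, '1.weight': torch.float32, '1.bias': torch.float32}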
@@ -2569,7 +2582,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if device_map != "sequential" and get_balanced_memory is not None:
                 max_memory = get_balanced_memory(
                     model,
-                    dtype=torch_dtype,
+                    dtype=torch_dtype if not load_in_8bit else torch.int8,
                     low_zero=(device_map == "balanced_low_0"),
                     max_memory=max_memory,
                     **kwargs,
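The only change in this hunk is the dtype used for memory estimation: with load_in_8bit the converted weights occupy roughly one byte per parameter, so sizing them with the requested torch_dtype (typically two or four bytes) would overestimate each GPU's share. A back-of-the-envelope sketch of that arithmetic; the helper below is illustrative only and not part of transformers or accelerate:

import torch

def estimated_param_bytes(numel: int, load_in_8bit: bool, torch_dtype=torch.float16) -> int:
    # Mirrors the dtype choice in the diff: size the weights as int8 when loading in 8-bit.
    dtype = torch.int8 if load_in_8bit else torch_dtype
    return numel * torch.tensor([], dtype=dtype).element_size()

# A 7B-parameter model: ~7 GB in int8 vs ~14 GB in float16.
print(estimated_param_bytes(7_000_000_000, load_in_8bit=True))   # 7000000000
print(estimated_param_bytes(7_000_000_000, load_in_8bit=False))  # 14000000000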
src/transformers/models/llama/modeling_llama.py (view file @ 66d1eee6)
@@ -785,7 +785,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model/pipeline parallelism
+            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
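The llama hunk only rewords the comment, but the surrounding context shows the pattern it refers to: when the model is split across devices, the labels may live on a different GPU than the logits, so they must be moved before the loss is computed. A self-contained sketch of that pattern with illustrative shapes (it falls back to CPU when fewer than two GPUs are present):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 32  # illustrative
logits_device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"

# Pretend the final decoder layers (and thus the logits) ended up on a different device than the labels.
logits = torch.randn(2, 8, vocab_size, device=logits_device)
labels = torch.randint(0, vocab_size, (2, 8))  # stays wherever the inputs were prepared

# Shift so that tokens < n predict token n, then flatten.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Enable model parallelism: the loss runs on the logits' device, so move the labels there first.
shift_labels = shift_labels.to(shift_logits.device)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())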