transformers · Commit 9fea71b4 (unverified)
Authored May 31, 2023 by Sylvain Gugger; committed by GitHub on May 31, 2023

Fix last instances of kbit -> quantized (#23797)

Parent: 38dbbc26
Changes: 2 changed files with 4 additions and 4 deletions

src/transformers/modeling_utils.py   +2 -2
src/transformers/trainer.py          +2 -2
src/transformers/modeling_utils.py
@@ -2237,7 +2237,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
                 logger.info(
                     f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
-                    "requirements of `bitsandbytes` to enable model loading in mixed kbit. "
+                    "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
                     "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                     " torch_dtype=torch.float16 to remove this warning."
                 )
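The message touched by this hunk is the warning users see when a quantized checkpoint is loaded without an explicit dtype. A minimal sketch of silencing it by passing `torch_dtype` yourself, assuming `bitsandbytes` and `accelerate` are installed and using a placeholder checkpoint name:

import torch
from transformers import AutoModelForCausalLM

# Passing torch_dtype explicitly keeps the non-quantized layers in float16
# and avoids the "Overriding torch_dtype" warning shown in the hunk above.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",        # placeholder checkpoint, any causal LM works
    load_in_8bit=True,          # quantized load via bitsandbytes
    device_map="auto",
    torch_dtype=torch.float16,
)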
@@ -2683,7 +2683,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             # training in 8-bit is only available in 0.37.0+
-            model._is_kbit_training_enabled = version.parse(
+            model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
             ) >= version.parse("0.37.0")
...
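The renamed flag is computed with a standard version gate: read the installed `bitsandbytes` version and compare it against the first release that supports 8-bit training. A self-contained sketch of the same pattern (the helper name is ours, not part of transformers):

import importlib.metadata

from packaging import version


def bnb_supports_quantized_training(minimum: str = "0.37.0") -> bool:
    # Look up the installed bitsandbytes version; an absent package means no support.
    try:
        installed = importlib.metadata.version("bitsandbytes")
    except importlib.metadata.PackageNotFoundError:
        return False
    # Training in 8-bit is only available from 0.37.0 onwards.
    return version.parse(installed) >= version.parse(minimum)


print(bnb_supports_quantized_training())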
src/transformers/trainer.py
@@ -403,8 +403,8 @@ class Trainer:
             )
         # At this stage the model is already loaded
-        if getattr(model, "is_loaded_in_kbit", False):
-            if getattr(model, "_is_kbit_training_enabled", False):
+        if getattr(model, "is_quantized", False):
+            if getattr(model, "_is_quantized_training_enabled", False):
                 logger.info(
                     "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                     " inside the model such as adapters using `peft` library and freeze the model weights. Please"
...
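Code outside Trainer that inspects these flags needs the same rename: `is_loaded_in_kbit` becomes `is_quantized` and `_is_kbit_training_enabled` becomes `_is_quantized_training_enabled`. A hedged sketch of a pre-training guard using the new names (the function and its error message are ours, not part of the library):

def assert_trainable_quantized_model(model) -> None:
    # "is_quantized" replaces the old "is_loaded_in_kbit" attribute.
    if not getattr(model, "is_quantized", False):
        return  # full-precision model, nothing to check
    # "_is_quantized_training_enabled" replaces "_is_kbit_training_enabled";
    # it is only True when a recent enough bitsandbytes (>= 0.37.0) is installed.
    if not getattr(model, "_is_quantized_training_enabled", False):
        raise ValueError(
            "This quantized model cannot be trained as-is: upgrade bitsandbytes to >= 0.37.0 "
            "and train through adapters (for example with the peft library)."
        )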