Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ae9230af
Unverified
Commit
ae9230af
authored
Feb 28, 2023
by
Younes Belkada
Committed by
GitHub
Feb 28, 2023
Browse files
[`T5`] Fix torchquant issue (#21843)
* fix torchquant issue * add tests
parent
2d506ea4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
6 deletions
+43
-6
src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/longt5/modeling_longt5.py
+5
-1
src/transformers/models/mt5/modeling_mt5.py
src/transformers/models/mt5/modeling_mt5.py
+10
-2
src/transformers/models/switch_transformers/modeling_switch_transformers.py
...odels/switch_transformers/modeling_switch_transformers.py
+5
-1
src/transformers/models/t5/modeling_t5.py
src/transformers/models/t5/modeling_t5.py
+10
-2
tests/models/t5/test_modeling_t5.py
tests/models/t5/test_modeling_t5.py
+13
-0
No files found.
src/transformers/models/longt5/modeling_longt5.py
View file @
ae9230af
...
...
@@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
src/transformers/models/mt5/modeling_mt5.py
View file @
ae9230af
...
...
@@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
@@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module):
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
...
...
src/transformers/models/switch_transformers/modeling_switch_transformers.py
View file @
ae9230af
...
...
@@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
src/transformers/models/t5/modeling_t5.py
View file @
ae9230af
...
...
@@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
@@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module):
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
...
...
tests/models/t5/test_modeling_t5.py
View file @
ae9230af
...
...
@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
def
tokenizer
(
self
):
return
T5Tokenizer
.
from_pretrained
(
"t5-base"
)
@
slow
def
test_torch_quant
(
self
):
r
"""
Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
"""
model_name
=
"google/flan-t5-small"
tokenizer
=
T5Tokenizer
.
from_pretrained
(
model_name
)
model
=
T5ForConditionalGeneration
.
from_pretrained
(
model_name
)
model
=
torch
.
quantization
.
quantize_dynamic
(
model
,
{
torch
.
nn
.
Linear
},
dtype
=
torch
.
qint8
)
input_text
=
"Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
input_ids
=
tokenizer
(
input_text
,
return_tensors
=
"pt"
).
input_ids
_
=
model
.
generate
(
input_ids
)
@
slow
def
test_small_generation
(
self
):
model
=
T5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
).
to
(
torch_device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment