Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ae9230af
Unverified
Commit
ae9230af
authored
Feb 28, 2023
by
Younes Belkada
Committed by
GitHub
Feb 28, 2023
Browse files
[`T5`] Fix torchquant issue (#21843)
* fix torchquant issue * add tests
parent
2d506ea4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
6 deletions
+43
-6
src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/longt5/modeling_longt5.py
+5
-1
src/transformers/models/mt5/modeling_mt5.py
src/transformers/models/mt5/modeling_mt5.py
+10
-2
src/transformers/models/switch_transformers/modeling_switch_transformers.py
...odels/switch_transformers/modeling_switch_transformers.py
+5
-1
src/transformers/models/t5/modeling_t5.py
src/transformers/models/t5/modeling_t5.py
+10
-2
tests/models/t5/test_modeling_t5.py
tests/models/t5/test_modeling_t5.py
+13
-0
No files found.
src/transformers/models/longt5/modeling_longt5.py
View file @
ae9230af
...
...
@@ -275,7 +275,11 @@ class LongT5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
src/transformers/models/mt5/modeling_mt5.py
View file @
ae9230af
...
...
@@ -145,7 +145,11 @@ class MT5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
@@ -170,7 +174,11 @@ class MT5DenseGatedActDense(nn.Module):
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
...
...
src/transformers/models/switch_transformers/modeling_switch_transformers.py
View file @
ae9230af
...
...
@@ -272,7 +272,11 @@ class SwitchTransformersDenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
src/transformers/models/t5/modeling_t5.py
View file @
ae9230af
...
...
@@ -288,7 +288,11 @@ class T5DenseActDense(nn.Module):
hidden_states
=
self
.
wi
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
return
hidden_states
...
...
@@ -312,7 +316,11 @@ class T5DenseGatedActDense(nn.Module):
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
if
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
:
if
(
isinstance
(
self
.
wo
.
weight
,
torch
.
Tensor
)
and
hidden_states
.
dtype
!=
self
.
wo
.
weight
.
dtype
and
self
.
wo
.
weight
.
dtype
!=
torch
.
int8
):
hidden_states
=
hidden_states
.
to
(
self
.
wo
.
weight
.
dtype
)
hidden_states
=
self
.
wo
(
hidden_states
)
...
...
tests/models/t5/test_modeling_t5.py
View file @
ae9230af
...
...
@@ -880,6 +880,19 @@ class T5ModelIntegrationTests(unittest.TestCase):
def tokenizer(self):
    """Return a T5 tokenizer loaded from the `t5-base` checkpoint."""
    checkpoint = "t5-base"
    return T5Tokenizer.from_pretrained(checkpoint)
@slow
def test_torch_quant(self):
    r"""
    Test that a simple `torch.quantization.quantize_dynamic` call works on a T5 model.
    """
    checkpoint = "google/flan-t5-small"
    # Load the tokenizer and model from the same checkpoint.
    tok = T5Tokenizer.from_pretrained(checkpoint)
    t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    # Dynamically quantize all Linear layers to int8 weights.
    t5_model = torch.quantization.quantize_dynamic(t5_model, {torch.nn.Linear}, dtype=torch.qint8)
    prompt = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
    encoded = tok(prompt, return_tensors="pt").input_ids
    # Smoke test: generation must run without dtype errors on the quantized model.
    _ = t5_model.generate(encoded)
@
slow
def
test_small_generation
(
self
):
model
=
T5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
).
to
(
torch_device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment