Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ba3fb4b8
Unverified
Commit
ba3fb4b8
authored
Jun 16, 2023
by
Arthur
Committed by
GitHub
Jun 16, 2023
Browse files
[`SwitchTransformers`] Fix return values (#24300)
* clean history * remove other changes * fix * fix coipes
parent
0b7b4429
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
28 deletions
+23
-28
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
...ormers/models/gptsan_japanese/modeling_gptsan_japanese.py
+1
-1
src/transformers/models/switch_transformers/modeling_switch_transformers.py
...odels/switch_transformers/modeling_switch_transformers.py
+22
-27
No files found.
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
View file @
ba3fb4b8
...
...
@@ -1348,7 +1348,7 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
total_router_logits
=
[]
total_expert_indexes
=
[]
for
router_output
in
router_outputs
:
if
router_output
[
0
]
is
not
None
:
if
len
(
router_output
[
0
]
.
shape
)
>
1
:
router_logits
,
expert_indexes
=
router_output
total_router_logits
.
append
(
router_logits
)
total_expert_indexes
.
append
(
expert_indexes
)
...
...
src/transformers/models/switch_transformers/modeling_switch_transformers.py
View file @
ba3fb4b8
...
...
@@ -798,7 +798,7 @@ class SwitchTransformersBlock(nn.Module):
if
isinstance
(
hidden_states
,
tuple
):
hidden_states
,
router_tuple
=
hidden_states
else
:
router_tuple
=
(
None
,)
router_tuple
=
(
torch
.
tensor
([
0
])
,)
# clamp inf values to enable fp16 training
if
hidden_states
.
dtype
==
torch
.
float16
and
torch
.
isinf
(
hidden_states
).
any
():
...
...
@@ -1683,50 +1683,45 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
decoder_z_loss
=
None
decoder_aux_loss
=
None
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
100
)
# todo check in the config if router loss enables
if
output_router_logits
:
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
encoder_router_logits
,
encoder_expert_indexes
=
self
.
_unpack_router_logits
(
encoder_outputs
.
router_probs
)
if
output_router_logits
:
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
if
self
.
encoder
.
config
.
encoder_sparse_step
>
1
:
encoder_router_logits
,
encoder_expert_indexes
=
self
.
_unpack_router_logits
(
encoder_outputs
[
-
1
])
encoder_z_loss
=
router_z_loss_func
(
encoder_router_logits
)
encoder_router_probs
=
nn
.
Softmax
(
dim
=-
1
)(
encoder_router_logits
)
encoder_aux_loss
=
load_balancing_loss_func
(
encoder_router_probs
,
encoder_expert_indexes
)
else
:
encoder_z_loss
=
0
encoder_aux_loss
=
0
decoder_router_logits
,
decoder_expert_indexes
=
self
.
_unpack_router_logits
(
decoder_outputs
.
router_probs
)
if
self
.
decoder
.
config
.
decoder_sparse_step
>
1
:
decoder_router_logits
,
decoder_expert_indexes
=
self
.
_unpack_router_logits
(
decoder_outputs
[
-
1
])
decoder_z_loss
=
router_z_loss_func
(
decoder_router_logits
)
decoder_router_probs
=
nn
.
Softmax
(
dim
=-
1
)(
decoder_router_logits
)
decoder_aux_loss
=
load_balancing_loss_func
(
decoder_router_probs
,
decoder_expert_indexes
)
else
:
decoder_z_loss
=
0
decoder_aux_loss
=
0
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
100
)
# move labels to correct device to enable PP
labels
=
labels
.
to
(
lm_logits
.
device
)
loss
=
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
labels
.
view
(
-
1
))
if
output_router_logits
and
labels
is
not
None
:
if
output_router_logits
:
z_loss
=
self
.
router_z_loss_coef
*
(
encoder_z_loss
+
decoder_z_loss
)
aux_loss
=
self
.
router_aux_loss_coef
*
(
encoder_aux_loss
+
decoder_aux_loss
)
loss
=
loss
+
z_loss
+
aux_loss
if
not
return_dict
:
output
=
(
lm_logits
,)
if
output_router_logits
:
# only return the loss if they are not None
output
+=
(
encoder_z_loss
,
encoder_aux_loss
,
decoder_z_loss
,
decoder_aux_loss
,
*
decoder_outputs
[
1
:],
*
encoder_outputs
,
)
else
:
output
+=
(
*
decoder_outputs
[
1
:],
*
encoder_outputs
)
if
output_router_logits
:
output
+=
(
encoder_z_loss
,
encoder_aux_loss
,
decoder_z_loss
,
decoder_aux_loss
)
output
+=
(
*
decoder_outputs
[
1
:],
*
encoder_outputs
)
return
((
loss
,)
+
output
)
if
loss
is
not
None
else
output
return
Seq2SeqMoEOutput
(
loss
=
loss
,
logits
=
lm_logits
,
...
...
@@ -1738,18 +1733,18 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
decoder_router_logits
=
decoder_outputs
.
router_probs
,
encoder_last_hidden_state
=
encoder_outputs
.
last_hidden_state
,
encoder_hidden_states
=
encoder_outputs
.
hidden_states
,
encoder_attentions
=
encoder_outputs
.
attentions
,
encoder_router_logits
=
encoder_outputs
.
router_probs
,
decoder_router_logits
=
decoder_outputs
.
router_probs
,
)
def
_unpack_router_logits
(
self
,
router_outputs
):
total_router_logits
=
[]
total_expert_indexes
=
[]
for
router_output
in
router_outputs
:
if
router_output
[
0
]
is
not
None
:
if
len
(
router_output
[
0
]
.
shape
)
>
1
:
router_logits
,
expert_indexes
=
router_output
total_router_logits
.
append
(
router_logits
)
total_expert_indexes
.
append
(
expert_indexes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment