chenpangpang / transformers · Commits

Commit ba3fb4b8 (unverified)
Authored Jun 16, 2023 by Arthur; committed by GitHub, Jun 16, 2023
[`SwitchTransformers`] Fix return values (#24300)
* clean history
* remove other changes
* fix
* fix copies
Parent: 0b7b4429
Showing 2 changed files with 23 additions and 28 deletions.
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py (+1, -1)
src/transformers/models/switch_transformers/modeling_switch_transformers.py (+22, -27)
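Editor's note (context, hedged): as the hunks below show, this commit computes the encoder and decoder router losses whenever output_router_logits is set (and the corresponding stack actually contains sparse layers), rather than only when labels are passed, and it normalizes what the forward pass returns on both the tuple and Seq2SeqMoEOutput paths. A minimal usage sketch assuming the public google/switch-base-8 checkpoint; the prompt and printed fields are illustrative, not part of the commit:

import torch
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")

inputs = tokenizer("A <extra_id_0> walks into a bar.", return_tensors="pt")
labels = tokenizer("<extra_id_0> man <extra_id_1>", return_tensors="pt").input_ids

# After this commit, the router losses are populated whenever
# output_router_logits=True, not only when labels are provided.
outputs = model(**inputs, labels=labels, output_router_logits=True)
print(outputs.loss)            # lm loss + z_loss + aux_loss
print(outputs.encoder_z_loss)  # no longer None in the labels-free case
print(outputs.decoder_aux_loss)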
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -1348,7 +1348,7 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
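Editor's note on the changed condition (hedged): after the SwitchTransformers hunk further down, a dense layer contributes (torch.tensor([0]),) instead of (None,) to the collected router outputs, so a None check no longer identifies placeholders; real router logits are at least 2-D while the placeholder is 1-D. A self-contained toy sketch of the filter; the tensor shapes are assumptions for illustration:

import torch

# Toy stand-ins for what the model collects per layer (shapes assumed):
router_outputs = [
    (torch.tensor([0]),),                          # dense layer placeholder, 1-D
    (torch.randn(2, 5, 8), torch.zeros(2, 5, 8)),  # sparse layer: (router_logits, expert_indexes)
]

total_router_logits, total_expert_indexes = [], []
for router_output in router_outputs:
    if len(router_output[0].shape) > 1:  # keeps real router logits, skips the 1-D placeholder
        router_logits, expert_indexes = router_output
        total_router_logits.append(router_logits)
        total_expert_indexes.append(expert_indexes)

print(len(total_router_logits))  # 1: the placeholder was filtered out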
src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -798,7 +798,7 @@ class SwitchTransformersBlock(nn.Module):
         if isinstance(hidden_states, tuple):
             hidden_states, router_tuple = hidden_states
         else:
-            router_tuple = (None,)
+            router_tuple = (torch.tensor([0]),)

         # clamp inf values to enable fp16 training
         if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
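Editor's reading of this change (hedged): a 1-D tensor placeholder gives every layer's router tuple the same tensor-only structure, so downstream consumers can probe .shape uniformly, and the 1-D shape itself marks "no router at this layer", which is exactly what the updated _unpack_router_logits tests. A tiny illustration:

import torch

# Old placeholder: touching router_output[0] requires a None check first.
old_placeholder = (None,)
# old_placeholder[0].shape  # would raise AttributeError: 'NoneType' object has no attribute 'shape'

# New placeholder: every entry supports the same shape probe; the 1-D shape
# is what marks it as "no router at this layer".
new_placeholder = (torch.tensor([0]),)
print(len(new_placeholder[0].shape) > 1)  # False -> skipped by _unpack_router_logits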
@@ -1683,50 +1683,45 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         decoder_z_loss = None
         decoder_aux_loss = None

-        if labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-100)
-            # todo check in the config if router loss enables
-
-            if output_router_logits:
-                # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
-                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs.router_probs)
+        if output_router_logits:
+            # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
+            if self.encoder.config.encoder_sparse_step > 1:
+                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1])
                 encoder_z_loss = router_z_loss_func(encoder_router_logits)
                 encoder_router_probs = nn.Softmax(dim=-1)(encoder_router_logits)
                 encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes)
+            else:
+                encoder_z_loss = 0
+                encoder_aux_loss = 0

-                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs.router_probs)
+            if self.decoder.config.decoder_sparse_step > 1:
+                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1])
                 decoder_z_loss = router_z_loss_func(decoder_router_logits)
                 decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits)
                 decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
+            else:
+                decoder_z_loss = 0
+                decoder_aux_loss = 0

+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+            # move labels to correct device to enable PP
+            labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

-            if output_router_logits and labels is not None:
+            if output_router_logits:
                 z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
                 aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
                 loss = loss + z_loss + aux_loss

         if not return_dict:
             output = (lm_logits,)
-            if output_router_logits:  # only return the loss if they are not None
-                output += (
-                    encoder_z_loss,
-                    encoder_aux_loss,
-                    decoder_z_loss,
-                    decoder_aux_loss,
-                    *decoder_outputs[1:],
-                    *encoder_outputs,
-                )
-            else:
-                output += (*decoder_outputs[1:], *encoder_outputs)
+            if output_router_logits:
+                output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss)
+            output += (*decoder_outputs[1:], *encoder_outputs)

             return ((loss,) + output) if loss is not None else output

         return Seq2SeqMoEOutput(
             loss=loss,
             logits=lm_logits,
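For reference (an editor's sketch, not part of the diff): router_z_loss_func and load_balancing_loss_func are helpers defined earlier in the same file. A condensed, self-contained approximation of the quantities this hunk combines, following the Switch Transformers formulation; the toy shapes (batch, seq_len, num_experts) are assumptions:

import torch

def z_loss_sketch(router_logits: torch.Tensor) -> torch.Tensor:
    # z-loss: mean over all tokens of (log-sum-exp of the router logits)^2,
    # which pushes the router logits to stay small.
    return torch.logsumexp(router_logits, dim=-1).pow(2).mean()

def aux_loss_sketch(router_probs: torch.Tensor, expert_indexes: torch.Tensor) -> torch.Tensor:
    # Load-balancing loss: num_experts * sum_i f_i * P_i, where f_i is the
    # fraction of tokens routed to expert i and P_i its mean router probability.
    num_experts = router_probs.shape[-1]
    expert_mask = torch.nn.functional.one_hot(expert_indexes, num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))  # f_i
    prob_per_expert = router_probs.mean(dim=(0, 1))   # P_i
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)

# Toy shapes (assumed): batch=2, seq_len=5, num_experts=4.
logits = torch.randn(2, 5, 4)
probs = torch.softmax(logits, dim=-1)
indexes = probs.argmax(dim=-1)
print(z_loss_sketch(logits) + aux_loss_sketch(probs, indexes))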
@@ -1738,18 +1733,18 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
+            decoder_router_logits=decoder_outputs.router_probs,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
             encoder_router_logits=encoder_outputs.router_probs,
-            decoder_router_logits=decoder_outputs.router_probs,
         )

     def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
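Putting the return-path hunks together (an inferred illustration, not documented API): with return_dict=False, the tuple now always ends with *decoder_outputs[1:], *encoder_outputs, and the four router losses are inserted after the logits only when output_router_logits=True. Reusing model, inputs, and labels from the sketch near the top:

# Inferred tuple layout when return_dict=False, labels are given, and
# output_router_logits=True (positions read off the diff above):
# (loss, lm_logits, encoder_z_loss, encoder_aux_loss,
#  decoder_z_loss, decoder_aux_loss, *decoder_outputs[1:], *encoder_outputs)
outputs = model(**inputs, labels=labels, output_router_logits=True, return_dict=False)
loss, lm_logits = outputs[0], outputs[1]
encoder_z_loss, encoder_aux_loss = outputs[2], outputs[3]
decoder_z_loss, decoder_aux_loss = outputs[4], outputs[5]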