Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
f098f250
Commit
f098f250
authored
Mar 31, 2025
by
dongcl
Browse files
cross entropy修改
parent
722e38bf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
0 deletions
+45
-0
dcu_megatron/core/tensor_parallel/cross_entropy.py
dcu_megatron/core/tensor_parallel/cross_entropy.py
+45
-0
No files found.
dcu_megatron/core/tensor_parallel/cross_entropy.py
0 → 100644
View file @
f098f250
import
torch
from
typing
import
Tuple
class VocabParallelCrossEntropy:
    """
    Computes the Cross Entropy Loss splitting the Vocab size across tensor parallel
    ranks. This implementation is used in both fused and unfused cross entropy implementations
    """

    @staticmethod
    def calculate_predicted_logits(
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        logits_max: torch.Tensor,
        vocab_start_index: int,
        vocab_end_index: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather each row's target logit and compute softmax statistics for
        this rank's vocabulary shard.

        NOTE: ``vocab_parallel_logits`` is mutated in place (max-shift, then
        exponentiation) to reduce memory pressure; the returned ``exp_logits``
        aliases it.

        Args:
            vocab_parallel_logits: logits for this rank's vocab partition,
                shape ``[..., partition_vocab_size]``. Modified in place.
            target: global-vocabulary target indices, shaped like the leading
                dimensions of ``vocab_parallel_logits``.
            logits_max: per-row maximum logit, subtracted for numerical
                stability before exponentiation.
            vocab_start_index: first global vocab id held by this shard.
            vocab_end_index: one past the last global vocab id of this shard.

        Returns:
            Tuple of ``(target_mask, masked_target_1d, predicted_logits,
            sum_exp_logits, exp_logits)``, where ``target_mask`` flags targets
            owned by other ranks, ``predicted_logits`` holds the (shifted)
            logit at each local target (zero where masked), and
            ``sum_exp_logits`` / ``exp_logits`` are the per-row partial
            softmax denominator and numerators for this shard.
        """
        # Shift by the row maximum in place: numerically stable softmax
        # without allocating a second logits-sized buffer.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # True where the target id lives on another tensor-parallel rank.
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        # Translate targets into this shard's local index space; masked
        # entries are forced to 0 so the gather below stays in bounds.
        masked_target = target.clone() - vocab_start_index
        masked_target *= ~target_mask

        # Flatten to 2-D [rows, shard_vocab] and 1-D [rows] so a single
        # advanced-indexing gather picks out logits[row, target[row]].
        shard_vocab_size = vocab_parallel_logits.size(-1)
        flat_logits = vocab_parallel_logits.view(-1, shard_vocab_size)
        masked_target_1d = masked_target.view(-1)
        row_indices = torch.arange(flat_logits.size(0), device=flat_logits.device)
        gathered = flat_logits[row_indices, masked_target_1d].clone().contiguous()

        predicted_logits = gathered.view_as(target)
        # Zero out rows whose true target belongs to a different shard.
        predicted_logits *= ~target_mask

        # Exponentiate in place; exp_logits aliases vocab_parallel_logits.
        exp_logits = vocab_parallel_logits.exp_()
        sum_exp_logits = exp_logits.sum(dim=-1)

        return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment