gaoqiong / flash-attention · Commits

Commit 7c995381, authored Nov 12, 2022 by Tri Dao
Parent: 55797f32

Add fused cross entropy loss

Showing 7 changed files with 1194 additions and 0 deletions (+1194 -0)
csrc/xentropy/interface.cpp                  +51   -0
csrc/xentropy/setup.py                       +131  -0
csrc/xentropy/xentropy_kernel.cu             +754  -0
flash_attn/losses/cross_entropy_apex.py      +51   -0
flash_attn/losses/cross_entropy_parallel.py  +112  -0
tests/losses/test_cross_entropy_apex.py      +39   -0
tests/losses/test_cross_entropy_parallel.py  +56   -0
csrc/xentropy/interface.cpp (new file, mode 0 → 100644)
#include <torch/extension.h>

// CUDA forward declarations

std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing) {
    CHECK_CUDA(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace) {
    CHECK_CUDA(grad_loss);
    CHECK_CUDA(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
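For reference, the Python wrappers later in this commit call these bindings directly. A minimal sketch of the call pattern, assuming the extension has been built and is importable as xentropy_cuda_lib (shapes chosen arbitrarily for illustration):

# Minimal sketch of calling the compiled extension directly.
import torch
import xentropy_cuda_lib

logits = torch.randn(32, 50257, device='cuda', dtype=torch.float16)
labels = torch.randint(0, 50257, (32,), device='cuda', dtype=torch.long)

# forward returns the per-sample losses and the max-log-sum-exp needed for backward
losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, 0.0)

# backward takes dL/dloss and produces dL/dlogits (optionally overwriting `logits` in place)
grad_loss = torch.ones_like(losses)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, 0.0, False)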
csrc/xentropy/setup.py (new file, mode 0 → 100644)
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess

import sys
import warnings
import os

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major = torch.version.cuda.split(".")[0]
    torch_binary_minor = torch.version.cuda.split(".")[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does "
            "not match the version used to compile Pytorch binaries.  "
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
            + "In some cases, a minor-version mismatch will not cause later errors:  "
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
            "You can try commenting out this check (at your own risk)."
        )


def raise_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    raise RuntimeError(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
    )
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) == 11:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
            if int(bare_metal_minor) > 0:
                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"


print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

cmdclass = {}
ext_modules = []

# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
    generator_flag = ["-DOLD_GENERATOR_PATH"]

raise_if_cuda_home_none("--xentropy")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")

ext_modules.append(
    CUDAExtension(
        name="xentropy_cuda_lib",
        sources=["interface.cpp", "xentropy_kernel.cu"],
        extra_compile_args={
            "cxx": ["-O3"] + generator_flag,
            "nvcc": append_nvcc_threads(["-O3"] + generator_flag + cc_flag),
        },
        include_dirs=[this_dir],
    )
)

setup(
    name="xentropy_cuda_lib",
    version="0.1",
    description="Cross-entropy loss",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)
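A hedged build note (not part of the commit): the extension is typically built by running pip install . (or python setup.py install) from csrc/xentropy, after which the module name declared above can be imported and sanity-checked:

# Hypothetical post-build sanity check; the import name matches the CUDAExtension name above.
import xentropy_cuda_lib

# Docstrings come from the m.def(...) calls in interface.cpp.
print(xentropy_cuda_lib.forward.__doc__)
print(xentropy_cuda_lib.backward.__doc__)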
csrc/xentropy/xentropy_kernel.cu (new file, mode 0 → 100644) — diff collapsed (+754 lines, not shown)
flash_attn/losses/cross_entropy_apex.py (new file, mode 0 → 100644)
import torch
import torch.nn as nn

import xentropy_cuda_lib


# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing)
        losses.masked_fill_(labels == padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.padding_idx, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None


class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
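A usage sketch for the module above, mirroring tests/losses/test_cross_entropy_apex.py added in this commit (assumes the package is importable as flash_attn and a CUDA device is available):

# Fused loss as a drop-in for torch.nn.CrossEntropyLoss on CUDA tensors.
import torch
from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex

logits = torch.randn(8 * 128, 50257, device='cuda', dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 50257, (8 * 128,), device='cuda', dtype=torch.long)
labels[:10] = -100  # ignore_index entries contribute zero loss

loss_fn = CrossEntropyLossApex(inplace_backward=True)  # reuse the logits buffer for the gradient
loss = loss_fn(logits, labels)  # scalar (reduction='mean' by default)
loss.backward()                 # gradient w.r.t. the logits ends up in logits.grad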
flash_attn/losses/cross_entropy_parallel.py (new file, mode 0 → 100644)
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, get_tensor_model_parallel_rank(),
            get_tensor_model_parallel_world_size()
        )

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
        ignored_mask = labels == ignored_index
        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
        masked_labels = labels_local.clone()
        masked_labels[labels_mask] = ignored_index

        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
        assert lse_local.shape == (batch,)
        assert losses.shape == (batch,)
        losses.masked_fill_(masked_labels == ignored_index, 0)

        if world_size > 1:
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=get_tensor_model_parallel_group())
            lse = torch.logsumexp(lse_allgather, dim=0)
            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
                                         group=get_tensor_model_parallel_group())
            # The losses are currently lse_local - predicted_logit, we just have to subtract the
            # lse_local and add the lse (global).
            rank_per_sample = labels // partition_vocab_size
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)
        else:
            lse = lse_local

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(input, target, self.label_smoothing,
                                                       self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
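The header comment above describes the key trick: each rank computes its local loss (lse_local - predicted_logit) and its local LSE, and a single exchange of LSEs plus one sum of the losses recovers the global loss. A single-process sketch of that identity (hypothetical illustration, not part of the commit; stacking stands in for all_gather):

# Each "rank" holds a vocabulary shard. The rank owning the label contributes
# lse_local - logit[label]; combining per-shard LSEs with logsumexp and swapping
# lse_local for lse_global yields the full cross-entropy loss.
import torch

batch, vocab, world_size = 4, 16, 2
shard = vocab // world_size
logits = torch.randn(batch, vocab)
labels = torch.randint(0, vocab, (batch,))

# per-shard LSEs; stacking plays the role of all_gather_into_tensor
lse_shards = torch.stack([torch.logsumexp(logits[:, r * shard:(r + 1) * shard], dim=-1)
                          for r in range(world_size)])  # (world_size, batch)
lse_global = torch.logsumexp(lse_shards, dim=0)         # combine LSEs, not raw exponentials

rank_per_sample = labels // shard
lse_local = lse_shards[rank_per_sample, torch.arange(batch)]
local_loss = lse_local - logits[torch.arange(batch), labels]

# global loss = local loss - lse_local + lse_global = lse_global - predicted logit
loss = local_loss - lse_local + lse_global
ref = torch.nn.functional.cross_entropy(logits, labels, reduction='none')
assert torch.allclose(loss, ref, atol=1e-5)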
tests/losses/test_cross_entropy_apex.py (new file, mode 0 → 100644)
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from src.losses.cross_entropy_apex import CrossEntropyLossApex

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50257])
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
    device = 'cuda'
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss()
    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
tests/losses/test_cross_entropy_parallel.py (new file, mode 0 → 100644)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from src.losses.cross_entropy_parallel import CrossEntropyLossParallel

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    partition_vocab_size = vocab_size // world_size
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10).requires_grad_()
    x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size],
                          rtol=rtol, atol=atol)

    parallel_state.destroy_model_parallel()