gaoqiong / flash-attention / Commits

Commit 7c995381, authored Nov 12, 2022 by Tri Dao
Parent: 55797f32

    Add fused cross entropy loss

Showing 7 changed files with 1194 additions and 0 deletions (+1194, -0)
Changed files:

    csrc/xentropy/interface.cpp                   +51   -0
    csrc/xentropy/setup.py                        +131  -0
    csrc/xentropy/xentropy_kernel.cu              +754  -0
    flash_attn/losses/cross_entropy_apex.py       +51   -0
    flash_attn/losses/cross_entropy_parallel.py   +112  -0
    tests/losses/test_cross_entropy_apex.py       +39   -0
    tests/losses/test_cross_entropy_parallel.py   +56   -0
csrc/xentropy/interface.cpp (new file, 0 → 100644)
#include <torch/extension.h>

// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing) {
    CHECK_CUDA(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace) {
    CHECK_CUDA(grad_loss);
    CHECK_CUDA(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels,
                                          smoothing, inplace);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
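A quick smoke test of the two bound entry points, as a minimal sketch: it assumes the extension has already been built and imported as xentropy_cuda_lib (the module name defined in setup.py below). forward returns a pair [losses, max_log_sum_exp], both float32 of shape (batch,); backward consumes that saved max_log_sum_exp.

import torch
import xentropy_cuda_lib

logits = torch.randn(4, 16, device="cuda", dtype=torch.float16)
labels = torch.randint(0, 16, (4,), device="cuda", dtype=torch.long)

# forward(input, labels, smoothing) -> [losses, max_log_sum_exp]
losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, 0.0)

# backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace) -> grad_logits
grad_loss = torch.ones_like(losses)            # must be float32, like the losses
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, 0.0, False)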
csrc/xentropy/setup.py (new file, 0 → 100644)
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess

import sys
import warnings
import os

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major = torch.version.cuda.split(".")[0]
    torch_binary_minor = torch.version.cuda.split(".")[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does "
            "not match the version used to compile Pytorch binaries.  "
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
            + "In some cases, a minor-version mismatch will not cause later errors:  "
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
            "You can try commenting out this check (at your own risk)."
        )


def raise_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    raise RuntimeError(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
    )
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) == 11:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
            if int(bare_metal_minor) > 0:
                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

cmdclass = {}
ext_modules = []

# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
    generator_flag = ["-DOLD_GENERATOR_PATH"]

raise_if_cuda_home_none("--xentropy")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")

ext_modules.append(
    CUDAExtension(
        name="xentropy_cuda_lib",
        sources=["interface.cpp", "xentropy_kernel.cu"],
        extra_compile_args={
            "cxx": ["-O3"] + generator_flag,
            "nvcc": append_nvcc_threads(["-O3"] + generator_flag + cc_flag),
        },
        include_dirs=[this_dir],
    )
)

setup(
    name="xentropy_cuda_lib",
    version="0.1",
    description="Cross-entropy loss",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)
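Besides installing the extension ahead of time from csrc/xentropy with this setup.py, the same two sources can also be JIT-compiled through PyTorch's cpp_extension loader. A minimal sketch of that alternative, assuming nvcc and a CUDA-enabled PyTorch are available (this is not the route used by the commit, just an equivalent way to get the xentropy_cuda_lib module):

# Optional JIT-build sketch; the ahead-of-time route is `pip install .` from csrc/xentropy.
from torch.utils.cpp_extension import load

xentropy_cuda_lib = load(
    name="xentropy_cuda_lib",
    sources=["csrc/xentropy/interface.cpp", "csrc/xentropy/xentropy_kernel.cu"],
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3"],
)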
csrc/xentropy/xentropy_kernel.cu (new file, 0 → 100644)
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// #else
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
// switch(TYPE) \
// { \
// case at::ScalarType::Float: \
// { \
// using scalar_t_##LEVEL = float; \
// __VA_ARGS__; \
// break; \
// } \
// case at::ScalarType::Half: \
// { \
// using scalar_t_##LEVEL = at::Half; \
// __VA_ARGS__; \
// break; \
// } \
// default: \
// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
// }
// #endif
#define ALIGN_BYTES 16
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};

const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < (max_block_size / 2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template <typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template <typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}

template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1, AccumT val1, const Reduction1<AccumT>& r1, AccumT defaultVal1,
            AccumT* reducVal2, AccumT val2, const Reduction2<AccumT>& r2, AccumT defaultVal2)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[blockDim.x + threadIdx.x] = val2;

  __syncthreads();

  AccumT warpVal1 = defaultVal1;
  AccumT warpVal2 = defaultVal2;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal1;
      smem[lane + blockDim.x] = warpVal2;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal1 = defaultVal1;
  AccumT blockVal2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
    smem[0] = blockVal1;
    smem[blockDim.x] = blockVal2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[blockDim.x];
  __syncthreads();
}

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}

template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2,
          int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, v[j]);
      threadVal2 = r2(threadVal2, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
    threadVal2 = r2(threadVal2, data[offset]);
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);

  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(),
      static_cast<accscalar_t>(0));
  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(),
      static_cast<accscalar_t>(0));

  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k),
      static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t lse = max_k + std::log(sumAll);
    if ((label >= 0) && (label < classes)) {
      accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
      losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
    } else {
      losses[blockIdx.x] = outscalar_t(0.f);
    }
    max_log_sum_exp[blockIdx.x] = lse;
  }
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
      int64_t *labels,
      const float smoothing,
      int classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) -
          static_cast<accscalar_t>((offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (
        std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
              scalar_t *gradInput,
              scalar_t *logits,
              outscalar_t *max_log_sum_exp,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
              int classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;

  // shift and do 1
  if (shift > 0) {
    logits -= shift;
    gradInput -= shift;
    classes += shift;
    if (threadIdx.x >= shift) {
      gradInput[offset] = tmpGradOutput * (
          std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -
          static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    classes -= blockDim.x;
    gradInput += blockDim.x;
    logits += blockDim.x;
    shift -= blockDim.x;
  }

  int last = classes % (ILP * blockDim.x);

  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
  // input
  scalar_t v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  // output
  scalar_t r[ILP];
  LoadT* result = reinterpret_cast<LoadT*>(&r);

  for (; offset * ILP < (classes - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(logits)[offset];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      r[j] = tmpGradOutput * (
          std::exp(static_cast<accscalar_t>(v[j]) - coeff) -
          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
  }

  offset = classes - last + threadIdx.x;
  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (
        std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;

  // Do vectorized load/store when input/output have same alignment
  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  if (shift == shift_){
    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
  }
  else {
    apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
  }
}

template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing){
  AT_ASSERTM(labels_.scalar_type() == ScalarType::Long, "Label type should be CUDA Long");

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)input_.get_device()};

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4) / sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
        losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
        input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
        dim_size, smoothing
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}

template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace) {
  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};

  const int64_t dim = 1;
  Tensor gI = inplace ? logits_ : at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4) / sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
        max_log_sum_exp.data_ptr<accscalar_t>(),
        grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
        smoothing, dim_size
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());
  return gI;
}

std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace) {
  AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(
      grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
}
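For reference, the math the two kernels implement, written as an unoptimized PyTorch sketch (a hypothetical equivalent for checking, not how the extension is called): per row, the forward pass saves lse = max + log(sum(exp(x - max))) and returns (lse - sum(x)/classes) * smoothing - (x[label] - lse) * (1 - smoothing), with the loss zeroed when the label is out of range; the backward pass multiplies the incoming gradient by softmax(x) - onehot(label) * (1 - smoothing) - smoothing/classes, reusing the saved lse instead of recomputing the softmax normalizer.

import torch
import torch.nn.functional as F

def xent_forward_reference(logits, labels, smoothing=0.0):
    # logits: (batch, classes), labels: (batch,) int64; computed in float32 like the kernel.
    classes = logits.shape[-1]
    x = logits.float()
    lse = torch.logsumexp(x, dim=-1)                      # this is the saved max_log_sum_exp
    row_sum = x.sum(dim=-1)
    log_prob = x.gather(-1, labels.clamp(0, classes - 1).unsqueeze(-1)).squeeze(-1) - lse
    losses = (lse - row_sum / classes) * smoothing - log_prob * (1.0 - smoothing)
    valid = (labels >= 0) & (labels < classes)
    losses = torch.where(valid, losses, torch.zeros_like(losses))
    return losses, lse

def xent_backward_reference(grad_loss, logits, lse, labels, smoothing=0.0):
    classes = logits.shape[-1]
    probs = torch.exp(logits.float() - lse.unsqueeze(-1))  # softmax via the saved lse ("coeff")
    valid = (labels >= 0) & (labels < classes)
    one_hot = torch.zeros_like(probs)
    one_hot[valid] = F.one_hot(labels[valid], classes).float()
    grad = probs - one_hot * (1.0 - smoothing) - smoothing / classes
    return (grad_loss.unsqueeze(-1) * grad).to(logits.dtype)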
flash_attn/losses/cross_entropy_apex.py (new file, 0 → 100644)
import torch
import torch.nn as nn

import xentropy_cuda_lib


# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing)
        losses.masked_fill_(labels == padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.padding_idx, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None


class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
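A minimal usage sketch of this module as a drop-in for torch.nn.CrossEntropyLoss; the import path assumes the flash_attn package layout shown in this commit (the tests below import the same class via src.losses.cross_entropy_apex instead):

import torch
from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex

logits = torch.randn(8 * 128, 50257, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 50257, (8 * 128,), device="cuda")

# inplace_backward=True reuses the logits buffer for the gradient to save memory.
loss_fn = CrossEntropyLossApex(inplace_backward=True)
loss = loss_fn(logits, labels)   # scalar, since reduction defaults to 'mean'
loss.backward()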
flash_attn/losses/cross_entropy_parallel.py (new file, 0 → 100644)
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, get_tensor_model_parallel_rank(),
            get_tensor_model_parallel_world_size()
        )

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
        ignored_mask = labels == ignored_index
        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
        masked_labels = labels_local.clone()
        masked_labels[labels_mask] = ignored_index

        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
        assert lse_local.shape == (batch,)
        assert losses.shape == (batch,)
        losses.masked_fill_(masked_labels == ignored_index, 0)

        if world_size > 1:
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=get_tensor_model_parallel_group())
            lse = torch.logsumexp(lse_allgather, dim=0)
            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
                                         group=get_tensor_model_parallel_group())
            # The losses are currently lse_local - predicted_logit, we just have to subtract the
            # lse_local and add the lse (global).
            rank_per_sample = labels // partition_vocab_size
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)
        else:
            lse = lse_local

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(input, target, self.label_smoothing,
                                                       self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
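To see why exchanging only the per-rank LSE and the per-rank losses is enough (the point of the comment at the top of this file), here is a single-process sketch that shards the vocabulary by hand and recombines exactly as forward() does; the sizes are made up for illustration. The key identity is that the global cross entropy is lse_global - logit[label], and lse_global is just the logsumexp of the per-shard LSEs, so the shard owning the label only needs the correction lse_global - lse_local.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, world_size = 4, 12, 3
shard = vocab // world_size
logits = torch.randn(batch, vocab)
labels = torch.randint(0, vocab, (batch,))

# Per-"rank" LSE over each vocabulary shard, then the global LSE from the shard LSEs.
lse_local = torch.stack([torch.logsumexp(logits[:, r * shard:(r + 1) * shard], dim=-1)
                         for r in range(world_size)])              # (world_size, batch)
lse_global = torch.logsumexp(lse_local, dim=0)                     # (batch,)

# After the all-reduce, each sample's summed loss is lse_local(owner rank) - logit[label],
# because only the rank holding the label contributes a nonzero local loss.
owner = labels // shard
label_logit = logits.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
losses = lse_local[owner, torch.arange(batch)] - label_logit

# The recombination step from forward(): drop the owner's local LSE, add the global one.
losses += lse_global - lse_local[owner, torch.arange(batch)]

assert torch.allclose(losses, F.cross_entropy(logits, labels, reduction="none"), atol=1e-6)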
tests/losses/test_cross_entropy_apex.py (new file, 0 → 100644)
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from src.losses.cross_entropy_apex import CrossEntropyLossApex

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50257])
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
    device = 'cuda'
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss()
    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
tests/losses/test_cross_entropy_parallel.py (new file, 0 → 100644)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from src.losses.cross_entropy_parallel import CrossEntropyLossParallel

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    partition_vocab_size = vocab_size // world_size
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
                        dtype=dtype) * 10).requires_grad_()
    x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad,
                          x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size],
                          rtol=rtol, atol=atol)

    parallel_state.destroy_model_parallel()