Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
db92ee13
Unverified
Commit
db92ee13
authored
Dec 14, 2021
by
Jithun Nair
Committed by
GitHub
Dec 14, 2021
Browse files
Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08
IFU-master-2021-12-08
parents
d150afdc
68364b49
Changes
98
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3316 additions
and
2290 deletions
+3316
-2290
apex/__init__.py
apex/__init__.py
+19
-1
apex/_autocast_utils.py
apex/_autocast_utils.py
+9
-0
apex/contrib/csrc/layer_norm/ln.h
apex/contrib/csrc/layer_norm/ln.h
+200
-0
apex/contrib/csrc/layer_norm/ln_api.cpp
apex/contrib/csrc/layer_norm/ln_api.cpp
+199
-59
apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh
apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh
+315
-0
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
+223
-427
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu
+120
-177
apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh
apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh
+110
-0
apex/contrib/csrc/layer_norm/ln_kernel_traits.h
apex/contrib/csrc/layer_norm/ln_kernel_traits.h
+156
-25
apex/contrib/csrc/layer_norm/ln_utils.cuh
apex/contrib/csrc/layer_norm/ln_utils.cuh
+733
-0
apex/contrib/csrc/layer_norm/utils.cuh
apex/contrib/csrc/layer_norm/utils.cuh
+0
-95
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp
...rc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp
+49
-66
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
...rc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
+60
-78
apex/contrib/csrc/multihead_attn/dropout.h
apex/contrib/csrc/multihead_attn/dropout.h
+197
-224
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp
...contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp
+103
-127
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
...contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
+141
-163
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp
...src/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp
+144
-170
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
...src/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
+157
-184
apex/contrib/csrc/multihead_attn/layer_norm.h
apex/contrib/csrc/multihead_attn/layer_norm.h
+330
-429
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp
...ontrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp
+51
-65
No files found.
apex/__init__.py
View file @
db92ee13
import
logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import
torch
import
torch
import
warnings
if
torch
.
distributed
.
is_available
():
if
torch
.
distributed
.
is_available
():
from
.
import
parallel
from
.
import
parallel
...
@@ -22,3 +24,19 @@ from . import pyprof
...
@@ -22,3 +24,19 @@ from . import pyprof
#common utilties to run tests on ROCm.
#common utilties to run tests on ROCm.
from
.
import
testing
from
.
import
testing
from
.
import
transformer
from
.
import
transformer
# Logging utilities mainly for apex.transformer module
class
RankInfoFormatter
(
logging
.
Formatter
):
def
format
(
self
,
record
):
from
apex.transformer.parallel_state
import
get_rank_info
record
.
rank_info
=
get_rank_info
()
return
super
().
format
(
record
)
_library_root_logger
=
logging
.
getLogger
(
__name__
)
handler
=
logging
.
StreamHandler
()
handler
.
setFormatter
(
RankInfoFormatter
(
"%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"
))
_library_root_logger
.
addHandler
(
handler
)
_library_root_logger
.
propagate
=
False
apex/_autocast_utils.py
View file @
db92ee13
from
typing
import
Optional
import
torch
import
torch
def
_get_current_dtype
(
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
torch
.
dtype
:
if
not
torch
.
is_autocast_enabled
():
return
torch
.
float
or
dtype
else
:
return
torch
.
get_autocast_gpu_dtype
()
def
_cast_if_autocast_enabled
(
*
args
):
def
_cast_if_autocast_enabled
(
*
args
):
if
not
torch
.
is_autocast_enabled
():
if
not
torch
.
is_autocast_enabled
():
return
args
return
args
...
...
apex/contrib/csrc/layer_norm/ln.h
0 → 100644
View file @
db92ee13
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace
layer_norm
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Params
>
struct
LaunchParams
{
size_t
workspace_bytes
;
size_t
barrier_size
;
cudaDeviceProp
*
props
;
cudaStream_t
stream
;
Params
params
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
ParamsBase
{
ParamsBase
()
:
ctas_per_col
(
0
)
,
rows
(
0
)
,
cols
(
0
)
,
x
(
nullptr
)
,
mu
(
nullptr
)
,
rs
(
nullptr
)
,
gamma
(
nullptr
)
,
workspace
(
nullptr
)
,
barrier
(
nullptr
)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int
ctas_per_col
;
// Input is interpreted as matrix. We normalize across columns.
int
rows
;
int
cols
;
// Common data pointers.
void
*
x
;
void
*
mu
;
void
*
rs
;
void
*
gamma
;
// Multi-CTA workspace in gmem.
void
*
workspace
;
// Multi-CTA sync barriers in gmem.
int
*
barrier
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
FwdParams
:
public
ParamsBase
{
FwdParams
()
:
ParamsBase
()
,
z
(
nullptr
)
,
beta
(
nullptr
)
,
epsilon
(
0.
f
)
{
}
// Output of LN FWD.
void
*
z
;
void
*
beta
;
float
epsilon
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
BwdParams
:
public
ParamsBase
{
BwdParams
()
:
ParamsBase
()
,
dz
(
nullptr
)
,
dbeta_part
(
nullptr
)
,
dgamma_part
(
nullptr
)
,
dx
(
nullptr
)
,
dbeta
(
nullptr
)
,
dgamma
(
nullptr
)
{
}
// Input: gradient wrt. LN FWD output.
void
*
dz
;
// Workspace for Wgrad pre-reduction.
void
*
dbeta_part
;
void
*
dgamma_part
;
// Output: Dgrad.
void
*
dx
;
// Output: Wgrad.
void
*
dbeta
;
void
*
dgamma
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using
FwdFunction
=
std
::
function
<
void
(
LaunchParams
<
FwdParams
>&
,
const
bool
)
>
;
using
BwdFunction
=
std
::
function
<
void
(
LaunchParams
<
BwdParams
>&
,
const
bool
)
>
;
using
FunctionKey
=
uint64_t
;
using
FwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
FwdFunction
>
;
using
BwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
BwdFunction
>
;
extern
FwdRegistry
FWD_FUNCS
;
extern
BwdRegistry
BWD_FUNCS
;
////////////////////////////////////////////////////////////////////////////////////////////////////
using
fp32
=
float
;
using
fp16
=
half
;
using
bf16
=
nv_bfloat16
;
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
TypeId
{};
template
<
>
struct
TypeId
<
fp16
>
{
constexpr
static
uint32_t
Value
=
0
;
};
template
<
>
struct
TypeId
<
bf16
>
{
constexpr
static
uint32_t
Value
=
1
;
};
template
<
>
struct
TypeId
<
fp32
>
{
constexpr
static
uint32_t
Value
=
2
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
S
>
struct
Type2Key
{
constexpr
static
uint32_t
Value
=
TypeId
<
T
>::
Value
<<
S
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
WeightType2Key
:
public
Type2Key
<
T
,
0
>
{};
template
<
typename
T
>
struct
InputType2Key
:
public
Type2Key
<
T
,
2
>
{};
template
<
typename
T
>
struct
OutputType2Key
:
public
Type2Key
<
T
,
4
>
{};
template
<
typename
T
>
struct
ComputeType2Key
:
public
Type2Key
<
T
,
6
>
{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
O
,
typename
C
>
struct
Types2Key
{
constexpr
static
uint32_t
Value
=
WeightType2Key
<
W
>::
Value
|
InputType2Key
<
I
>::
Value
|
OutputType2Key
<
O
>::
Value
|
ComputeType2Key
<
C
>::
Value
;
constexpr
static
inline
uint64_t
get
(
const
uint64_t
hidden_size
){
constexpr
uint64_t
type_key
=
Value
;
return
(
type_key
<<
32
)
|
hidden_size
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
FwdRegistrar
{
FwdRegistrar
(
FwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
FWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
BwdRegistrar
{
BwdRegistrar
(
BwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
BWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace layer_norm
apex/contrib/csrc/layer_norm/ln_api.cpp
View file @
db92ee13
#include <torch/extension.h>
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/CUDAContext.h"
void
ln_fwd_cuda
(
at
::
Tensor
&
y
,
at
::
Tensor
&
mu
,
at
::
Tensor
&
rsigma
,
#include "ln.h"
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
gamma
,
const
at
::
Tensor
&
beta
,
const
float
epsilon
,
const
int
rows
,
const
int
cols
,
cudaStream_t
stream
);
void
ln_bwd_cuda
(
at
::
Tensor
&
dx
,
at
::
Tensor
&
dgamma
,
at
::
Tensor
&
dbeta
,
/*
const
at
::
Tensor
&
dw
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
mu
,
const
at
::
Tensor
&
rsigma
,
const
at
::
Tensor
&
gamma
,
const
int
rows
,
const
int
cols
,
cudaStream_t
stream
);
Supported Type combinations:
input compute weights output
=======================================
fp32 fp32 fp32 fp32
fp16 fp32 fp16 fp16
bf16 fp32 bf16 bf16
fp32 fp32 fp16 fp16
fp32 fp32 bf16 bf16
Remarks:
Output type = Weight type
Compute always in FP32
*/
namespace
layer_norm
{
// Create registries and provide runtime versions of config hash functions.
FwdRegistry
FWD_FUNCS
;
BwdRegistry
BWD_FUNCS
;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t
get_type_id
(
torch
::
Dtype
dtype
){
if
(
dtype
==
torch
::
kFloat16
)
{
return
TypeId
<
fp16
>::
Value
;
}
else
if
(
dtype
==
torch
::
kBFloat16
)
{
return
TypeId
<
bf16
>::
Value
;
}
else
if
(
dtype
==
torch
::
kFloat32
)
{
return
TypeId
<
fp32
>::
Value
;
}
else
{
TORCH_CHECK
(
false
,
"Type not supported: "
,
dtype
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t
get_key
(
torch
::
Dtype
wtype
,
torch
::
Dtype
itype
,
torch
::
Dtype
otype
,
torch
::
Dtype
ctype
,
uint64_t
hidden_size
)
{
using
namespace
layer_norm
;
uint64_t
type_key
=
get_type_id
(
wtype
)
|
(
get_type_id
(
itype
)
<<
2
)
|
(
get_type_id
(
otype
)
<<
4
)
|
(
get_type_id
(
ctype
)
<<
6
);
uint64_t
launcher_key
=
(
type_key
<<
32
)
|
hidden_size
;
return
launcher_key
;
}
}
// namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm
::
FwdFunction
&
get_fwd_launcher
(
torch
::
Dtype
wtype
,
torch
::
Dtype
itype
,
torch
::
Dtype
otype
,
torch
::
Dtype
ctype
,
uint32_t
hidden_size
)
{
auto
iter
=
layer_norm
::
FWD_FUNCS
.
find
(
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
));
if
(
iter
!=
layer_norm
::
FWD_FUNCS
.
end
()
)
{
return
iter
->
second
;
}
else
{
TORCH_CHECK
(
false
,
"FWD: Unsupported hidden_size or types: "
,
hidden_size
,
wtype
,
itype
,
otype
,
ctype
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm
::
BwdFunction
&
get_bwd_launcher
(
torch
::
Dtype
wtype
,
torch
::
Dtype
itype
,
torch
::
Dtype
otype
,
torch
::
Dtype
ctype
,
uint32_t
hidden_size
)
{
auto
iter
=
layer_norm
::
BWD_FUNCS
.
find
(
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
));
if
(
iter
!=
layer_norm
::
BWD_FUNCS
.
end
()
)
{
return
iter
->
second
;
}
else
{
TORCH_CHECK
(
false
,
"BWD: Unsupported hidden_size or types: "
,
hidden_size
,
wtype
,
itype
,
otype
,
ctype
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std
::
vector
<
at
::
Tensor
>
ln_fwd
(
const
at
::
Tensor
&
x
,
// BxSxhidden_size
std
::
vector
<
at
::
Tensor
>
ln_fwd
(
const
at
::
Tensor
&
x
,
// BxSxhidden_size
const
at
::
Tensor
&
gamma
,
// hidden_size
const
at
::
Tensor
&
gamma
,
// hidden_size
const
at
::
Tensor
&
beta
,
// hidden_size
const
at
::
Tensor
&
beta
,
// hidden_size
const
float
epsilon
const
float
epsilon
)
{
)
{
auto
itype
=
x
.
scalar_type
();
auto
wtype
=
gamma
.
scalar_type
();
auto
otype
=
wtype
;
auto
ctype
=
torch
::
kFloat32
;
TORCH_CHECK
(
beta
.
scalar_type
()
==
wtype
);
TORCH_CHECK
(
x
.
is_cuda
())
TORCH_CHECK
(
x
.
is_cuda
())
TORCH_CHECK
(
gamma
.
is_cuda
())
TORCH_CHECK
(
gamma
.
is_cuda
())
...
@@ -28,79 +99,148 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
...
@@ -28,79 +99,148 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size
const
int
rows
=
sizes
[
0
];
const
int
rows
=
sizes
[
0
];
const
int
cols
=
sizes
[
1
];
const
int
cols
=
sizes
[
1
];
auto
hidden_size
=
gamma
.
numel
();
auto
dtype
=
x
.
scalar_type
();
TORCH_CHECK
(
gamma
.
dtype
()
==
dtype
);
TORCH_CHECK
(
beta
.
dtype
()
==
dtype
);
TORCH_CHECK
(
gamma
.
sizes
()
==
beta
.
sizes
());
TORCH_CHECK
(
gamma
.
sizes
()
==
beta
.
sizes
());
TORCH_CHECK
(
gamma
.
numel
()
==
cols
);
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
(
epsilon
>=
0.
f
);
TORCH_CHECK
(
epsilon
>=
0.
f
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
opts
=
x
.
options
();
auto
y
=
torch
::
empty
_like
(
x
);
auto
z
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
otype
)
);
auto
opts
=
x
.
options
();
auto
mu
=
torch
::
empty
({
rows
},
opts
.
dtype
(
ctype
));
auto
rsigma
=
torch
::
empty
({
rows
},
opts
.
dtype
(
ctype
));
auto
mu
=
torch
::
empty
({
rows
},
opts
.
dtype
(
torch
::
kFloat32
));
layer_norm
::
LaunchParams
<
layer_norm
::
FwdParams
>
launch_params
;
auto
rsigma
=
torch
::
empty
({
rows
},
opts
.
dtype
(
torch
::
kFloat32
));
ln_fwd_cuda
(
y
,
mu
,
rsigma
,
x
,
gamma
,
beta
,
epsilon
,
rows
,
cols
,
stream
);
launch_params
.
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
launch_params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
return
{
y
,
mu
,
rsigma
};
// Request the kernel launcher.
}
auto
launcher
=
get_fwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
);
// Query the kernel-specific launch parameters.
launcher
(
launch_params
,
true
);
at
::
Tensor
workspace
,
barrier
;
// Set the kernel runtime parameters.
layer_norm
::
FwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
params
.
cols
=
cols
;
params
.
x
=
x
.
data_ptr
();
params
.
mu
=
mu
.
data_ptr
();
params
.
rs
=
rsigma
.
data_ptr
();
params
.
gamma
=
gamma
.
data_ptr
();
params
.
beta
=
beta
.
data_ptr
();
params
.
z
=
z
.
data_ptr
();
params
.
epsilon
=
epsilon
;
if
(
launch_params
.
barrier_size
>
0
)
{
auto
options
=
x
.
options
();
barrier
=
torch
::
zeros
(
launch_params
.
barrier_size
,
options
.
dtype
(
torch
::
kInt32
));
workspace
=
torch
::
empty
(
launch_params
.
workspace_bytes
,
options
.
dtype
(
torch
::
kChar
));
params
.
workspace
=
workspace
.
data_ptr
();
params
.
barrier
=
barrier
.
data_ptr
<
int
>
();
}
// Launch the kernel.
launcher
(
launch_params
,
false
);
return
{
z
,
mu
,
rsigma
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std
::
vector
<
at
::
Tensor
>
ln_bwd
(
const
at
::
Tensor
&
d
w
,
// BxSxhidden_size
std
::
vector
<
at
::
Tensor
>
ln_bwd
(
const
at
::
Tensor
&
d
z
,
// BxSxhidden_size
const
at
::
Tensor
&
x
,
// BxSxhidden_size
const
at
::
Tensor
&
x
,
// BxSxhidden_size
const
at
::
Tensor
&
mu
,
// BxS, FP32!
const
at
::
Tensor
&
mu
,
// BxS, FP32!
const
at
::
Tensor
&
rsigma
,
// BxS, FP32!
const
at
::
Tensor
&
rsigma
,
// BxS, FP32!
const
at
::
Tensor
&
gamma
// hidden_size
const
at
::
Tensor
&
gamma
// hidden_size
)
{
)
{
auto
itype
=
x
.
scalar_type
();
auto
wtype
=
gamma
.
scalar_type
();
auto
otype
=
wtype
;
auto
ctype
=
torch
::
kFloat32
;
TORCH_CHECK
(
dz
.
dtype
()
==
otype
);
TORCH_CHECK
(
mu
.
dtype
()
==
ctype
);
TORCH_CHECK
(
rsigma
.
dtype
()
==
ctype
);
TORCH_CHECK
(
x
.
is_cuda
());
TORCH_CHECK
(
x
.
is_cuda
());
TORCH_CHECK
(
d
w
.
is_cuda
());
TORCH_CHECK
(
d
z
.
is_cuda
());
TORCH_CHECK
(
mu
.
is_cuda
());
TORCH_CHECK
(
mu
.
is_cuda
());
TORCH_CHECK
(
rsigma
.
is_cuda
());
TORCH_CHECK
(
rsigma
.
is_cuda
());
TORCH_CHECK
(
gamma
.
is_cuda
());
TORCH_CHECK
(
gamma
.
is_cuda
());
TORCH_CHECK
(
x
.
is_contiguous
());
TORCH_CHECK
(
x
.
is_contiguous
());
TORCH_CHECK
(
d
w
.
is_contiguous
());
TORCH_CHECK
(
d
z
.
is_contiguous
());
auto
sizes
=
x
.
sizes
();
auto
sizes
=
x
.
sizes
();
TORCH_CHECK
(
sizes
.
size
()
==
2
);
TORCH_CHECK
(
sizes
.
size
()
==
2
);
TORCH_CHECK
(
d
w
.
sizes
()
==
sizes
);
TORCH_CHECK
(
d
z
.
sizes
()
==
sizes
);
auto
rows
=
sizes
[
0
];
auto
rows
=
sizes
[
0
];
auto
cols
=
sizes
[
1
];
auto
cols
=
sizes
[
1
];
auto
dtype
=
x
.
scalar_type
();
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
(
dw
.
dtype
()
==
dtype
);
TORCH_CHECK
(
gamma
.
dtype
()
==
dtype
);
TORCH_CHECK
(
mu
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
rsigma
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
TORCH_CHECK
(
gamma
.
numel
()
==
cols
);
TORCH_CHECK
(
gamma
.
numel
()
==
cols
);
auto
options
=
x
.
options
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
dx
=
torch
::
empty_like
(
x
);
auto
dx
=
torch
::
empty_like
(
x
);
auto
dgamma
=
torch
::
empty_like
(
gamma
);
auto
dgamma
=
torch
::
empty_like
(
gamma
);
auto
dbeta
=
torch
::
empty_like
(
gamma
);
auto
dbeta
=
torch
::
empty_like
(
gamma
);
ln_bwd_cuda
(
dx
,
dgamma
,
dbeta
,
dw
,
x
,
mu
,
rsigma
,
gamma
,
rows
,
cols
,
stream
);
layer_norm
::
LaunchParams
<
layer_norm
::
BwdParams
>
launch_params
;
launch_params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
return
{
dx
,
dgamma
,
dbeta
};
launch_params
.
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
auto
launcher
=
get_bwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
);
launcher
(
launch_params
,
true
);
auto
dgamma_part
=
torch
::
empty
({
launch_params
.
params
.
ctas_per_col
,
hidden_size
},
options
.
dtype
(
ctype
));
auto
dbeta_part
=
torch
::
empty
({
launch_params
.
params
.
ctas_per_col
,
hidden_size
},
options
.
dtype
(
ctype
));
at
::
Tensor
workspace
,
barrier
;
layer_norm
::
BwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
params
.
cols
=
cols
;
params
.
x
=
x
.
data_ptr
();
params
.
mu
=
mu
.
data_ptr
();
params
.
rs
=
rsigma
.
data_ptr
();
params
.
gamma
=
gamma
.
data_ptr
();
params
.
dz
=
dz
.
data_ptr
();
params
.
dx
=
dx
.
data_ptr
();
params
.
dbeta
=
dbeta
.
data_ptr
();
params
.
dgamma
=
dgamma
.
data_ptr
();
params
.
dbeta_part
=
dbeta_part
.
data_ptr
();
params
.
dgamma_part
=
dgamma_part
.
data_ptr
();
if
(
launch_params
.
barrier_size
>
0
)
{
// TODO Any way to avoid this?
barrier
=
torch
::
zeros
(
launch_params
.
barrier_size
,
options
.
dtype
(
torch
::
kInt32
));
workspace
=
torch
::
empty
(
launch_params
.
workspace_bytes
,
options
.
dtype
(
torch
::
kChar
));
params
.
workspace
=
workspace
.
data_ptr
();
params
.
barrier
=
barrier
.
data_ptr
<
int
>
();
}
launcher
(
launch_params
,
false
);
return
{
dx
,
dgamma
,
dbeta
,
dgamma_part
,
dbeta_part
};
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
doc
()
=
"CUDA LayerNorm"
;
// optional module docstring
m
.
doc
()
=
"CUDA LayerNorm"
;
m
.
def
(
"ln_fwd"
,
&
ln_fwd
,
"Run LayerNorm forward kernel"
);
m
.
def
(
"ln_fwd"
,
&
ln_fwd
,
"Run LayerNorm forward kernel"
);
m
.
def
(
"ln_bwd"
,
&
ln_bwd
,
"Run LayerNorm backward kernel"
);
m
.
def
(
"ln_bwd"
,
&
ln_bwd
,
"Run LayerNorm backward kernel"
);
}
}
apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh
0 → 100644
View file @
db92ee13
#pragma once
namespace
layer_norm
{
template
<
typename
Ktraits
>
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_bwd_kernel
(
layer_norm
::
BwdParams
params
)
{
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
enum
{
COLS
=
Ktraits
::
COLS
};
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
enum
{
LDGS
=
Ktraits
::
LDGS
};
enum
{
NUM_ELTS
=
Ktraits
::
ELTS_PER_LDG
};
enum
{
THREADS_PER_WARP
=
Ktraits
::
THREADS_PER_WARP
};
enum
{
CTAS_PER_ROW
=
Ktraits
::
CTAS_PER_ROW
};
using
compute_t
=
typename
Ktraits
::
compute_t
;
using
index_t
=
typename
Ktraits
::
index_t
;
using
Ivec
=
typename
Ktraits
::
Ivec
;
using
Ovec
=
typename
Ktraits
::
Ovec
;
using
Wvec
=
typename
Ktraits
::
Wvec
;
using
Cvec
=
typename
Ktraits
::
Cvec
;
using
Reducer
=
typename
Ktraits
::
Reducer
;
using
reduce_t
=
typename
Reducer
::
Type
;
extern
__shared__
char
smem_
[];
const
index_t
tidx
=
threadIdx
.
x
;
const
index_t
bidn
=
blockIdx
.
x
%
CTAS_PER_ROW
;
const
index_t
bidm
=
blockIdx
.
x
/
CTAS_PER_ROW
;
const
index_t
lane
=
tidx
%
THREADS_PER_WARP
;
const
index_t
warp
=
tidx
/
THREADS_PER_WARP
;
const
index_t
warp_m
=
warp
/
Ktraits
::
WARPS_N
;
const
index_t
warp_n
=
warp
%
Ktraits
::
WARPS_N
;
const
index_t
tid_r
=
warp_n
*
THREADS_PER_WARP
+
lane
;
const
index_t
r
=
bidm
*
Ktraits
::
ROWS_PER_CTA
+
warp_m
;
const
index_t
c
=
bidn
*
THREADS_PER_ROW
+
warp_n
*
THREADS_PER_WARP
+
lane
;
static_assert
(
COLS
==
THREADS_PER_ROW
*
LDGS
*
NUM_ELTS
*
CTAS_PER_ROW
);
Cvec
dzy_sum
[
LDGS
];
Cvec
dz_sum
[
LDGS
];
memset
(
dzy_sum
,
0
,
sizeof
(
dzy_sum
));
memset
(
dz_sum
,
0
,
sizeof
(
dz_sum
));
compute_t
*
smem_wgrad
=
reinterpret_cast
<
compute_t
*>
(
smem_
);
char
*
smem_dgrad
=
smem_
+
Ktraits
::
SMEM_BYTES_WGRAD
;
Reducer
reducer
(
params
,
bidm
,
bidn
,
warp_m
,
warp_n
,
lane
,
smem_dgrad
);
Sum
<
reduce_t
>
sum
;
constexpr
float
rn
=
1.
f
/
float
(
COLS
);
Wvec
gamma
[
LDGS
];
index_t
idx
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
gamma
[
it
].
load_from
(
params
.
gamma
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for
(
int
row
=
r
;
row
<
params
.
rows
;
row
+=
params
.
ctas_per_col
*
ROWS_PER_CTA
)
{
const
compute_t
mu_r
=
static_cast
<
const
compute_t
*>
(
params
.
mu
)[
row
];
const
compute_t
rs_r
=
static_cast
<
const
compute_t
*>
(
params
.
rs
)[
row
];
Ivec
x
[
LDGS
];
Ovec
dz
[
LDGS
];
index_t
idx
=
row
*
Ktraits
::
VEC_COLS
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz
[
it
].
load_from
(
params
.
dz
,
idx
);
x
[
it
].
load_from
(
params
.
x
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
compute_t
dy
[
LDGS
*
NUM_ELTS
];
compute_t
y
[
LDGS
*
NUM_ELTS
];
compute_t
mdy_local
=
0.
f
;
compute_t
mdyy_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
x_tmp
=
x
[
it
].
data
.
elt
[
jt
];
compute_t
y_tmp
=
rs_r
*
(
x_tmp
-
mu_r
);
compute_t
dy_tmp
=
compute_t
(
gamma
[
it
].
data
.
elt
[
jt
]);
dy_tmp
*=
compute_t
(
dz
[
it
].
data
.
elt
[
jt
]);
compute_t
dz_tmp
=
dz
[
it
].
data
.
elt
[
jt
];
mdy_local
+=
dy_tmp
;
mdyy_local
+=
dy_tmp
*
y_tmp
;
dy
[
it
*
NUM_ELTS
+
jt
]
=
dy_tmp
;
y
[
it
*
NUM_ELTS
+
jt
]
=
y_tmp
;
dzy_sum
[
it
].
data
.
elt
[
jt
]
+=
dz_tmp
*
y_tmp
;
dz_sum
[
it
].
data
.
elt
[
jt
]
+=
dz_tmp
;
}
}
reduce_t
result
=
reducer
.
allreduce
({
mdy_local
,
mdyy_local
},
sum
);
mdy_local
=
layer_norm
::
Get
<
0
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
rn
;
mdyy_local
=
layer_norm
::
Get
<
1
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
rn
;
Ivec
dx
[
LDGS
];
idx
=
row
*
Ktraits
::
VEC_COLS
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
dy_tmp
=
dy
[
it
*
NUM_ELTS
+
jt
];
compute_t
y_tmp
=
y
[
it
*
NUM_ELTS
+
jt
];
compute_t
dx_tmp
=
rs_r
*
(
dy_tmp
-
(
mdyy_local
*
y_tmp
+
mdy_local
));
dx
[
it
].
data
.
elt
[
jt
]
=
dx_tmp
;
}
dx
[
it
].
store_to
(
params
.
dx
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
// end: grid stride loop
if
(
WARPS_M
==
1
)
{
idx
=
r
*
Ktraits
::
VEC_COLS
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz_sum
[
it
].
store_to
(
params
.
dbeta_part
,
idx
);
dzy_sum
[
it
].
store_to
(
params
.
dgamma_part
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
else
{
static_assert
(
WARPS_M
==
1
||
Ktraits
::
CTAS_PER_ROW
==
1
,
"Multiple rows per CTA not supported for Multi-CTA."
);
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum
{
NUM_RES
=
COLS
/
Ktraits
::
THREADS_PER_CTA
};
static_assert
(
NUM_RES
*
Ktraits
::
THREADS_PER_CTA
==
COLS
,
""
);
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
compute_t
cta_dz_sum
[
NUM_RES
];
memset
(
cta_dz_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dz_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
__syncthreads
();
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dzy_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
compute_t
cta_dzy_sum
[
NUM_RES
];
memset
(
cta_dzy_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dzy_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
compute_t
*
dgamma_part
=
static_cast
<
compute_t
*>
(
params
.
dgamma_part
)
+
bidm
*
COLS
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
*
dgamma_part
=
cta_dzy_sum
[
jt
];
dgamma_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
compute_t
*
dbeta_part
=
static_cast
<
compute_t
*>
(
params
.
dbeta_part
)
+
bidm
*
COLS
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
*
dbeta_part
=
cta_dz_sum
[
jt
];
dbeta_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
}
}
template
<
typename
Kernel_traits
>
__global__
__launch_bounds__
(
Kernel_traits
::
THREADS_PER_CTA
)
void
ln_bwd_finalize_kernel
(
BwdParams
params
)
{
using
compute_t
=
typename
Kernel_traits
::
compute_t
;
using
weight_t
=
typename
Kernel_traits
::
weight_t
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
using
Reducer
=
typename
Kernel_traits
::
Reducer
;
using
reduce_t
=
typename
Reducer
::
Type
;
Sum
<
reduce_t
>
sum
;
enum
{
NUM_ELT
=
Kernel_traits
::
ELTS_PER_LDG
};
enum
{
THREADS_PER_WARP
=
Kernel_traits
::
THREADS_PER_WARP
};
__shared__
char
smem_
[
Kernel_traits
::
SMEM_BYTES_PER_CTA
];
constexpr
uint32_t
bidm
=
0
;
const
uint32_t
bidn
=
blockIdx
.
x
;
const
uint32_t
tidx
=
threadIdx
.
x
;
const
uint32_t
warp
=
tidx
/
THREADS_PER_WARP
;
const
uint32_t
lane
=
tidx
%
THREADS_PER_WARP
;
Reducer
reducer
(
params
,
bidm
,
bidn
,
0
,
0
,
lane
,
smem_
);
const
uint32_t
c
=
bidn
*
THREADS_PER_WARP
+
lane
;
const
uint32_t
c_out
=
bidn
*
THREADS_PER_WARP
/
2
+
lane
;
constexpr
uint32_t
COL_STRIDE
=
Kernel_traits
::
CTAS
*
THREADS_PER_WARP
;
for
(
uint32_t
col
=
c
,
col_out
=
c_out
;
col
<
Kernel_traits
::
COLS
;
col
+=
COL_STRIDE
,
col_out
+=
COL_STRIDE
/
2
)
{
// Each thread sums over NUM_ELT columns.
Vec
<
compute_t
,
NUM_ELT
>
dbeta_local
,
dgamma_local
;
memset
(
&
dgamma_local
,
0
,
sizeof
(
dgamma_local
));
memset
(
&
dbeta_local
,
0
,
sizeof
(
dbeta_local
));
for
(
uint32_t
row
=
warp
;
row
<
params
.
ctas_per_col
;
row
+=
Kernel_traits
::
ROWS_PER_CTA
)
{
index_t
idx
=
row
*
Kernel_traits
::
COLS
+
col
;
Vec
<
compute_t
,
NUM_ELT
>
dbeta_part
,
dgamma_part
;
dbeta_part
.
load_from
(
params
.
dbeta_part
,
idx
);
dgamma_part
.
load_from
(
params
.
dgamma_part
,
idx
);
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
dgamma_local
.
data
.
elt
[
it
]
+=
dgamma_part
.
data
.
elt
[
it
];
dbeta_local
.
data
.
elt
[
it
]
+=
dbeta_part
.
data
.
elt
[
it
];
}
}
void
*
smem_gamma
=
smem_
;
void
*
smem_beta
=
&
smem_
[
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
const
int
write_row
=
warp
;
const
int
write_col
=
lane
^
write_row
;
const
int
write_idx
=
write_row
*
THREADS_PER_WARP
+
write_col
;
dgamma_local
.
store_to
(
smem_gamma
,
write_idx
);
dbeta_local
.
store_to
(
smem_beta
,
write_idx
);
__syncthreads
();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void
*
smem_gamma_out
=
&
smem_
[
2
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
void
*
smem_beta_out
=
&
smem_
[
2
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
+
Kernel_traits
::
SMEM_BYTES_OUTPUT
];
// More than one iter iff ROWS_PER_CTA < 32.
for
(
int
w
=
warp
;
w
<
THREADS_PER_WARP
;
w
+=
Kernel_traits
::
ROWS_PER_CTA
)
{
const
int
read_row
=
lane
;
const
int
read_col
=
w
^
read_row
;
const
int
read_idx
=
read_row
*
THREADS_PER_WARP
+
read_col
;
memset
(
&
dbeta_local
,
0
,
sizeof
(
dbeta_local
));
memset
(
&
dgamma_local
,
0
,
sizeof
(
dgamma_local
));
// Load beta and gamma transposed
if
(
read_row
<
Kernel_traits
::
ROWS_PER_CTA
){
dbeta_local
.
load_from
(
smem_beta
,
read_idx
);
dgamma_local
.
load_from
(
smem_gamma
,
read_idx
);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
compute_t
b_i
=
dbeta_local
.
data
.
elt
[
it
];
compute_t
g_i
=
dgamma_local
.
data
.
elt
[
it
];
b_i
=
reducer
.
allreduce
(
b_i
,
sum
);
g_i
=
reducer
.
allreduce
(
g_i
,
sum
);
dgamma_local
.
data
.
elt
[
it
]
=
g_i
;
dbeta_local
.
data
.
elt
[
it
]
=
b_i
;
}
// Leader stores the result at the current column.
if
(
lane
==
0
){
dgamma_local
.
store_to
(
smem_gamma_out
,
w
);
dbeta_local
.
store_to
(
smem_beta_out
,
w
);
}
}
// All writes done.
__syncthreads
();
// Pack and store: 2-wide stores with half the threads.
if
(
warp
==
Kernel_traits
::
ROWS_PER_CTA
-
1
&&
lane
<
THREADS_PER_WARP
/
2
)
{
using
src_t
=
typename
TypeToVec2
<
compute_t
>::
Type
;
using
dst_t
=
typename
TypeToVec2
<
weight_t
>::
Type
;
Vec
<
src_t
,
NUM_ELT
>
dbeta_vec2
,
dgamma_vec2
;
Vec
<
dst_t
,
NUM_ELT
>
dbeta_out2
,
dgamma_out2
;
dgamma_vec2
.
load_from
(
smem_gamma_out
,
lane
);
dbeta_vec2
.
load_from
(
smem_beta_out
,
lane
);
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
dgamma_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dgamma_vec2
.
data
.
elt
[
it
]);
dbeta_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dbeta_vec2
.
data
.
elt
[
it
]);
}
dgamma_out2
.
store_to
(
params
.
dgamma
,
col_out
);
dbeta_out2
.
store_to
(
params
.
dbeta
,
col_out
);
}
}
}
}
// namespace layer_norm
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
View file @
db92ee13
#include "utils.cuh"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
#include "ln_bwd_kernels.cuh"
template
<
typename
Ktraits
>
using
namespace
layer_norm
;
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_bwd_kernel
(
void
*
__restrict__
dx_
,
void
*
__restrict__
dg_
,
template
<
void
*
__restrict__
db_
,
typename
weight_t
,
const
void
*
__restrict__
dw_
,
typename
input_t
,
const
void
*
__restrict__
x_
,
typename
output_t
,
const
void
*
__restrict__
mu_
,
typename
compute_t
,
const
void
*
__restrict__
rs_
,
typename
index_t
,
const
void
*
__restrict__
g_
,
int
HIDDEN_SIZE
,
const
int
rows
int
CTAS_PER_ROW
,
){
int
WARPS_M
,
using
Vec
=
typename
Ktraits
::
Vec
;
int
WARPS_N
,
int
BYTES_PER_LDG_MAIN
,
enum
{
BYTES_PER_LDG
=
Ktraits
::
BYTES_PER_LDG
};
int
BYTES_PER_LDG_FINAL
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
>
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
void
launch_
(
LaunchParams
<
BwdParams
>
&
launch_params
,
const
bool
configure_params
){
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
enum
{
COLS
=
Ktraits
::
COLS
};
input_t
,
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
output_t
,
enum
{
LDGS
=
BYTES_PER_ROW
/
Ktraits
::
BYTES_PER_ROW_PER_CTA
};
compute_t
,
static_assert
(
LDGS
*
Ktraits
::
BYTES_PER_ROW_PER_CTA
==
BYTES_PER_ROW
,
""
);
index_t
,
enum
{
NUM_ELTS
=
Vec
::
NUM_ELTS
};
HIDDEN_SIZE
,
using
vec_t
=
typename
Ktraits
::
vec_t
;
CTAS_PER_ROW
,
using
base_t
=
typename
Ktraits
::
base_t
;
WARPS_M
,
using
compute_t
=
typename
Ktraits
::
compute_t
;
WARPS_N
,
const
int
tidx
=
threadIdx
.
x
;
BYTES_PER_LDG_MAIN
const
int
bidx
=
blockIdx
.
x
;
>
;
const
int
lane
=
tidx
%
THREADS_PER_WARP
;
auto
kernel
=
&
ln_bwd_kernel
<
Kernel_traits
>
;
const
int
warp
=
tidx
/
THREADS_PER_WARP
;
const
int
warp_m
=
warp
/
Ktraits
::
WARPS_N
;
if
(
configure_params
)
{
const
int
warp_n
=
warp
%
Ktraits
::
WARPS_N
;
int
ctas_per_sm
;
const
int
tid_r
=
warp_n
*
THREADS_PER_WARP
+
lane
;
cudaError
status_
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
);
const
int
r
=
bidx
*
Ktraits
::
ROWS_PER_CTA
+
warp_m
;
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
const
int
c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
const
char
*
dw_ptr
=
static_cast
<
const
char
*>
(
dw_
);
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
const
char
*
x_ptr
=
static_cast
<
const
char
*>
(
x_
);
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
const
char
*
g_ptr
=
static_cast
<
const
char
*>
(
g_
);
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
char
*
dx_ptr
=
static_cast
<
char
*>
(
dx_
);
*
Kernel_traits
::
WARPS_M
const
compute_t
*
mu_ptr
=
static_cast
<
const
compute_t
*>
(
mu_
);
*
Kernel_traits
::
CTAS_PER_ROW
const
compute_t
*
rs_ptr
=
static_cast
<
const
compute_t
*>
(
rs_
);
*
sizeof
(
typename
Kernel_traits
::
reduce_t
)
static_assert
(
COLS
==
THREADS_PER_ROW
*
LDGS
*
NUM_ELTS
,
""
);
*
2
;
}
// smem for final reduction
return
;
//__shared__ compute_t smem_[ROWS_PER_CTA * COLS];
}
extern
__shared__
compute_t
smem_
[];
// static_assert(sizeof(smem_dw_sum) == 32*1024,"");
if
(
Kernel_traits
::
SMEM_BYTES
>=
48
*
1024
)
{
// Using the grid stride loop we can assign multiple rows to each thread
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
// by using a number of CTAs smaller than rows / ROWS_PER_CTA
}
// We accumulate them here, one in smem, one in registers, because the smem
auto
stream
=
launch_params
.
stream
;
// capacity is limited compute_t * dw_sum = &smem_dw_sum[warp_m * COLS + tid_r
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
// * LDGS * NUM_ELTS];
compute_t
dwy_sum
[
LDGS
*
NUM_ELTS
];
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
compute_t
dw_sum
[
LDGS
*
NUM_ELTS
];
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
,
stream
>>>
(
launch_params
.
params
);
memset
(
dwy_sum
,
0
,
sizeof
(
compute_t
)
*
LDGS
*
NUM_ELTS
);
memset
(
dw_sum
,
0
,
sizeof
(
compute_t
)
*
LDGS
*
NUM_ELTS
);
// Debug 8 rows, 4B, 1024 cols
__shared__
compute_t
smem_mdy
[
ROWS_PER_CTA
*
WARPS_N
];
__shared__
compute_t
smem_mdyy
[
ROWS_PER_CTA
*
WARPS_N
];
compute_t
*
mdy_shared
=
&
smem_mdy
[
warp_m
*
WARPS_N
];
compute_t
*
mdyy_shared
=
&
smem_mdyy
[
warp_m
*
WARPS_N
];
constexpr
float
rn
=
1.
f
/
float
(
COLS
);
Vec
gamma
[
LDGS
];
int
col
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
gamma
[
it
].
load_from
(
g_ptr
+
col
*
BYTES_PER_LDG
);
col
+=
Ktraits
::
THREADS_PER_ROW
;
}
// TODO if ROWS_PER_CTA does not divice rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
const
compute_t
mu_r
=
mu_ptr
[
row
];
const
compute_t
rs_r
=
rs_ptr
[
row
];
Vec
dw
[
LDGS
],
x
[
LDGS
],
dx
[
LDGS
];
int
col
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dw
[
it
].
load_from
(
dw_ptr
+
row
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
);
x
[
it
].
load_from
(
x_ptr
+
row
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
);
col
+=
THREADS_PER_ROW
;
}
// local reductions
compute_t
dy
[
LDGS
*
NUM_ELTS
];
compute_t
y
[
LDGS
*
NUM_ELTS
];
compute_t
mdy_local
=
0.
f
;
compute_t
mdyy_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
Vec
::
NUM_ELTS
;
jt
++
)
{
compute_t
x_tmp
=
x
[
it
].
data
.
elt
[
jt
];
compute_t
y_tmp
=
rs_r
*
(
x_tmp
-
mu_r
);
compute_t
dy_tmp
=
gamma
[
it
].
data
.
elt
[
jt
]
*
dw
[
it
].
data
.
elt
[
jt
];
compute_t
dw_tmp
=
dw
[
it
].
data
.
elt
[
jt
];
mdy_local
+=
dy_tmp
;
mdyy_local
+=
dy_tmp
*
y_tmp
;
dy
[
it
*
NUM_ELTS
+
jt
]
=
dy_tmp
;
y
[
it
*
NUM_ELTS
+
jt
]
=
y_tmp
;
dwy_sum
[
it
*
NUM_ELTS
+
jt
]
+=
dw_tmp
*
y_tmp
;
dw_sum
[
it
*
NUM_ELTS
+
jt
]
+=
dw_tmp
;
}
}
// reduction across row for mdy, mdyy
if
(
WARPS_N
==
1
)
{
// no need to go through smem!
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mdy_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mdy_local
,
it
);
mdyy_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mdyy_local
,
it
);
}
mdy_local
*=
rn
;
mdyy_local
*=
rn
;
}
else
{
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
#pragma unroll
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
for
(
int
it
=
16
;
it
>
0
;
it
/=
2
)
{
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
mdy_local
+=
__shfl_down_sync
(
uint32_t
(
-
1
),
mdy_local
,
it
);
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES
,
stream
);
mdyy_local
+=
__shfl_down_sync
(
uint32_t
(
-
1
),
mdyy_local
,
it
);
}
}
// lane 0 holds the result!
using
Kernel_traits_f
=
layer_norm
::
Kernel_traits_finalize
<
HIDDEN_SIZE
,
if
(
lane
==
0
)
{
weight_t
,
mdy_shared
[
warp_n
]
=
mdy_local
;
input_t
,
mdyy_shared
[
warp_n
]
=
mdyy_local
;
output_t
,
}
compute_t
,
index_t
,
__syncthreads
();
32
*
32
,
// THREADS_PER_CTA
if
(
warp_n
==
0
&&
lane
==
0
)
{
BYTES_PER_LDG_FINAL
>
;
mdy_local
=
0.
f
;
mdyy_local
=
0.
f
;
auto
kernel_f
=
&
layer_norm
::
ln_bwd_finalize_kernel
<
Kernel_traits_f
>
;
for
(
int
it
=
0
;
it
<
WARPS_N
;
it
++
)
{
kernel_f
<<<
Kernel_traits_f
::
CTAS
,
Kernel_traits_f
::
THREADS_PER_CTA
,
0
,
stream
>>>
(
launch_params
.
params
);
mdy_local
+=
mdy_shared
[
it
];
mdyy_local
+=
mdyy_shared
[
it
];
}
mdy_shared
[
0
]
=
mdy_local
;
mdyy_shared
[
0
]
=
mdyy_local
;
}
__syncthreads
();
mdy_local
=
mdy_shared
[
0
]
*
rn
;
mdyy_local
=
mdyy_shared
[
0
]
*
rn
;
}
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
dy_tmp
=
dy
[
it
*
NUM_ELTS
+
jt
];
compute_t
y_tmp
=
y
[
it
*
NUM_ELTS
+
jt
];
compute_t
dx_tmp
=
compute_t
(
rs_r
)
*
(
dy_tmp
-
mdyy_local
*
y_tmp
-
mdy_local
);
dx
[
it
].
data
.
elt
[
jt
]
=
dx_tmp
;
}
}
col
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dx
[
it
].
store_to
(
dx_ptr
+
row
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
);
col
+=
Ktraits
::
THREADS_PER_ROW
;
}
}
// end: grid stride loop
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
enum
{
NUM_RES
=
COLS
/
Ktraits
::
THREADS_PER_CTA
};
static_assert
(
NUM_RES
*
Ktraits
::
THREADS_PER_CTA
==
COLS
,
""
);
compute_t
*
smem_write
;
smem_write
=
&
smem_
[
warp_m
*
COLS
+
tid_r
*
NUM_ELTS
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
smem_write
[
jt
]
=
dw_sum
[
it
*
NUM_ELTS
+
jt
];
}
smem_write
+=
THREADS_PER_ROW
*
NUM_ELTS
;
}
__syncthreads
();
compute_t
cta_dw_sum
[
NUM_RES
];
memset
(
cta_dw_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dw_sum
[
jt
]
+=
smem_
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
__syncthreads
();
smem_write
=
&
smem_
[
warp_m
*
COLS
+
tid_r
*
NUM_ELTS
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
smem_write
[
jt
]
=
dwy_sum
[
it
*
NUM_ELTS
+
jt
];
}
smem_write
+=
THREADS_PER_ROW
*
NUM_ELTS
;
}
__syncthreads
();
compute_t
cta_dwy_sum
[
NUM_RES
];
memset
(
cta_dwy_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dwy_sum
[
jt
]
+=
smem_
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
compute_t
*
dgamma_part
=
static_cast
<
compute_t
*>
(
dg_
)
+
bidx
*
COLS
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
*
dgamma_part
=
cta_dwy_sum
[
jt
];
dgamma_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
compute_t
*
dbeta_part
=
static_cast
<
compute_t
*>
(
db_
)
+
bidx
*
COLS
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
*
dbeta_part
=
cta_dw_sum
[
jt
];
dbeta_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
}
}
template
<
typename
Ktraits
,
typename
out_t
>
// Create backward launch function and register. Macro signature:
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_bwd_finalize_kernel
(
void
*
__restrict__
dg_
,
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
void
*
__restrict__
db_
,
const
void
*
__restrict__
dg_part_
,
REGISTER_BWD_LAUNCHER
(
768
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
const
void
*
__restrict__
db_part_
,
REGISTER_BWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
const
int
rows
REGISTER_BWD_LAUNCHER
(
768
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
){
REGISTER_BWD_LAUNCHER
(
768
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
using
Vec
=
typename
Ktraits
::
Vec
;
REGISTER_BWD_LAUNCHER
(
768
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
enum
{
NUM_ELTS
=
Vec
::
NUM_ELTS
};
REGISTER_BWD_LAUNCHER
(
1024
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
using
vec_t
=
typename
Ktraits
::
vec_t
;
REGISTER_BWD_LAUNCHER
(
1024
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
using
base_t
=
typename
Ktraits
::
base_t
;
REGISTER_BWD_LAUNCHER
(
1024
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
using
compute_t
=
typename
Ktraits
::
compute_t
;
REGISTER_BWD_LAUNCHER
(
1024
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
enum
{
BYTES_PER_LDG
=
Ktraits
::
BYTES_PER_LDG
};
REGISTER_BWD_LAUNCHER
(
1536
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
REGISTER_BWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
REGISTER_BWD_LAUNCHER
(
1536
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
REGISTER_BWD_LAUNCHER
(
1536
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
REGISTER_BWD_LAUNCHER
(
1536
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
enum
{
COLS
=
Ktraits
::
COLS
};
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
REGISTER_BWD_LAUNCHER
(
2048
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
enum
{
VEC_COLS
=
BYTES_PER_ROW
/
BYTES_PER_LDG
};
REGISTER_BWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
//dbg
REGISTER_BWD_LAUNCHER
(
2048
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
static_assert
(
VEC_COLS
==
COLS
/
NUM_ELTS
,
""
);
REGISTER_BWD_LAUNCHER
(
2048
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
//static_assert(VEC_COLS == 1024,"");
REGISTER_BWD_LAUNCHER
(
2048
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
const
int
tidx
=
threadIdx
.
x
;
const
int
bidx
=
blockIdx
.
x
;
REGISTER_BWD_LAUNCHER
(
2304
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
8
,
4
);
const
int
lane
=
tidx
%
THREADS_PER_WARP
;
REGISTER_BWD_LAUNCHER
(
2304
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
const
int
warp
=
tidx
/
THREADS_PER_WARP
;
REGISTER_BWD_LAUNCHER
(
2304
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
const
int
warp_m
=
warp
/
Ktraits
::
WARPS_N
;
REGISTER_BWD_LAUNCHER
(
2304
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
const
int
warp_n
=
warp
%
Ktraits
::
WARPS_N
;
REGISTER_BWD_LAUNCHER
(
2304
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
const
int
tid_c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
const
int
c
=
bidx
*
THREADS_PER_ROW
+
tid_c
;
REGISTER_BWD_LAUNCHER
(
3072
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
const
int
r
=
warp_m
;
REGISTER_BWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
3072
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
__shared__
compute_t
smem_
[(
WARPS_M
-
1
)
*
THREADS_PER_ROW
*
NUM_ELTS
];
REGISTER_BWD_LAUNCHER
(
3072
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
3072
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
//Will probably run this with WARPS_N = 1 and grid = 1024 / (32*4) = 8, or NUM_ELTS=1 and grid = 32
// and WARPS_M = 4 (or 1??)
REGISTER_BWD_LAUNCHER
(
3840
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
8
,
4
);
for
(
int
col
=
c
;
col
<
VEC_COLS
;
col
+=
gridDim
.
x
*
THREADS_PER_ROW
){
REGISTER_BWD_LAUNCHER
(
3840
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
const
char
*
dg_part_ptr
=
static_cast
<
const
char
*>
(
dg_part_
)
+
r
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
;
REGISTER_BWD_LAUNCHER
(
3840
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
const
char
*
db_part_ptr
=
static_cast
<
const
char
*>
(
db_part_
)
+
r
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
;
REGISTER_BWD_LAUNCHER
(
3840
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
3840
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
compute_t
dg_sum
[
NUM_ELTS
];
compute_t
db_sum
[
NUM_ELTS
];
REGISTER_BWD_LAUNCHER
(
4096
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
memset
(
dg_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_ELTS
);
REGISTER_BWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
memset
(
db_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_ELTS
);
REGISTER_BWD_LAUNCHER
(
4096
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
#pragma unroll
REGISTER_BWD_LAUNCHER
(
4096
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
ROWS_PER_CTA
){
REGISTER_BWD_LAUNCHER
(
4096
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
Vec
dg
;
Vec
db
;
REGISTER_BWD_LAUNCHER
(
5120
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
dg
.
load_from
(
dg_part_ptr
);
REGISTER_BWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
db
.
load_from
(
db_part_ptr
);
REGISTER_BWD_LAUNCHER
(
5120
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
dg_part_ptr
+=
ROWS_PER_CTA
*
BYTES_PER_ROW
;
REGISTER_BWD_LAUNCHER
(
5120
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
db_part_ptr
+=
ROWS_PER_CTA
*
BYTES_PER_ROW
;
REGISTER_BWD_LAUNCHER
(
5120
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
#pragma unroll
REGISTER_BWD_LAUNCHER
(
6144
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
REGISTER_BWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
dg_sum
[
jt
]
+=
dg
.
data
.
elt
[
jt
];
REGISTER_BWD_LAUNCHER
(
6144
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
db_sum
[
jt
]
+=
db
.
data
.
elt
[
jt
];
REGISTER_BWD_LAUNCHER
(
6144
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
}
REGISTER_BWD_LAUNCHER
(
6144
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
}
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
,
4
);
// Finalize the reduction across rows of the CTA
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
compute_t
*
smem_write
;
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
smem_write
=
smem_
+
(
warp_m
-
1
)
*
THREADS_PER_ROW
*
NUM_ELTS
+
tid_c
;
REGISTER_BWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
bf16
,
fp32
,
bf16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
bf16
,
fp32
,
bf16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
fp32
,
fp32
,
fp32
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
fp16
,
fp16
,
fp16
,
fp32
,
5
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
fp16
,
fp32
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
bf16
,
bf16
,
bf16
,
fp32
,
5
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
bf16
,
fp32
,
bf16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
fp32
,
fp32
,
fp32
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
fp16
,
fp16
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
fp16
,
fp32
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
bf16
,
bf16
,
bf16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
bf16
,
fp32
,
bf16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
8
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
bf16
,
bf16
,
bf16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
bf16
,
fp32
,
bf16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
bf16
,
bf16
,
bf16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
bf16
,
fp32
,
bf16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
if
(
warp_m
>
0
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
*
smem_write
=
dg_sum
[
jt
];
smem_write
+=
THREADS_PER_ROW
;
}
}
__syncthreads
();
compute_t
*
smem_read
;
smem_read
=
smem_
+
tid_c
;
if
(
warp_m
==
0
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARPS_M
-
1
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
dg_sum
[
jt
]
+=
*
smem_read
;
smem_read
+=
THREADS_PER_ROW
;
}
}
}
__syncthreads
();
smem_write
=
smem_
+
(
warp_m
-
1
)
*
THREADS_PER_ROW
*
NUM_ELTS
+
tid_c
;
if
(
warp_m
>
0
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
*
smem_write
=
db_sum
[
jt
];
smem_write
+=
THREADS_PER_ROW
;
}
}
__syncthreads
();
smem_read
=
smem_
+
tid_c
;
if
(
warp_m
==
0
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARPS_M
-
1
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
db_sum
[
jt
]
+=
*
smem_read
;
smem_read
+=
THREADS_PER_ROW
;
}
}
using
vout_t
=
typename
Vec_type
<
sizeof
(
out_t
)
*
NUM_ELTS
>::
Type
;
union
{
vout_t
raw
;
out_t
elt
[
NUM_ELTS
];
}
dg_out
,
db_out
;
// out_t dg_out[NUM_ELTS], db_out[NUM_ELTS];
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
dg_out
.
elt
[
jt
]
=
dg_sum
[
jt
];
db_out
.
elt
[
jt
]
=
db_sum
[
jt
];
}
vout_t
*
dg_ptr
=
reinterpret_cast
<
vout_t
*>
(
dg_
)
+
col
;
vout_t
*
db_ptr
=
reinterpret_cast
<
vout_t
*>
(
db_
)
+
col
;
*
dg_ptr
=
dg_out
.
raw
;
*
db_ptr
=
db_out
.
raw
;
}
}
}
// Launch the two-step layer-norm backward for one supported hidden size.
// Step 1: ln_bwd_kernel computes dx and per-CTA partial reductions of
// dgamma/dbeta into the fp32 workspace tensors (dgamma_part, dbeta_part).
// Step 2: ln_bwd_finalize_kernel reduces the `gridx` partial rows into the
// final dgamma/dbeta.
// Only cols == 1024 is implemented; other widths assert.
template <typename scalar_t>
void launch(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
            at::Tensor &dgamma_part, at::Tensor &dbeta_part,
            const at::Tensor &dw, const at::Tensor &x, const at::Tensor &mu,
            const at::Tensor &rsigma, const at::Tensor &gamma, const int rows,
            const int cols, const int gridx, cudaStream_t stream) {
  if (cols == 1024) {
    using Ktraits = Kernel_traits<scalar_t, 1024, 4, 1>;
    // Kernels requesting more than the default 48KB of dynamic shared memory
    // must opt in explicitly before launch.
    if (Ktraits::SMEM_BYTES >= 48 * 1024) {
      AT_CUDA_CHECK(cudaFuncSetAttribute(
          ln_bwd_kernel<Ktraits>, cudaFuncAttributeMaxDynamicSharedMemorySize,
          Ktraits::SMEM_BYTES));
    }
    // Main backward pass: writes dx, accumulates dgamma/dbeta partials
    // (one row of partials per CTA) into the *_part workspaces.
    ln_bwd_kernel<Ktraits>
        <<<gridx, Ktraits::THREADS_PER_CTA, Ktraits::SMEM_BYTES, stream>>>(
            dx.data_ptr(), dgamma_part.data_ptr(), dbeta_part.data_ptr(),
            dw.data_ptr(), x.data_ptr(), mu.data_ptr(), rsigma.data_ptr(),
            gamma.data_ptr(), rows);
    // Finalize traits are fixed to fp32 partials; NOTE(review): this
    // Kernel_traits use takes 5 template args vs 4 above — presumably a
    // defaulted trailing parameter, confirm against ln_kernel_traits.h.
    using Ktraits2 = Kernel_traits<float, 1024, 16, 1, 4>;
    // One CTA per slice of columns covered by THREADS_PER_ROW * NUM_ELTS.
    constexpr int grid2 =
        DIVUP(1024, Ktraits2::THREADS_PER_ROW * Ktraits2::Vec::NUM_ELTS);
    // Reduce the `gridx` rows of partial sums into the final dgamma/dbeta.
    ln_bwd_finalize_kernel<Ktraits2, scalar_t>
        <<<grid2, Ktraits2::THREADS_PER_CTA, 0, stream>>>(
            dgamma.data_ptr(), dbeta.data_ptr(), dgamma_part.data_ptr(),
            dbeta_part.data_ptr(), gridx);
  } else {
    assert(false && "Not implemented");
  }
  AT_CUDA_CHECK(cudaPeekAtLastError());
}
// Entry point for the layer-norm backward pass.
// Allocates fp32 workspace for the two-step dgamma/dbeta reduction, then
// dispatches on the input dtype to the templated launcher. Supported dtypes
// are fp16 and fp32; anything else asserts.
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
                 const at::Tensor &dw, const at::Tensor &x,
                 const at::Tensor &mu, const at::Tensor &rsigma,
                 const at::Tensor &gamma, const int rows, const int cols,
                 cudaStream_t stream) {
  // Launch 2 CTAs per SM.
  const auto props = at::cuda::getCurrentDeviceProperties();
  const int grid = 2 * props->multiProcessorCount;

  // Request workspace for the two-step reduction. We always reduce in FP32,
  // regardless of the input dtype.
  const auto fp32_opts = x.options().dtype(torch::kFloat32);
  auto dbeta_part = torch::empty({grid, cols}, fp32_opts);
  auto dgamma_part = torch::empty({grid, cols}, fp32_opts);

  const auto dtype = x.scalar_type();
  if (dtype == torch::kFloat16) {
    launch<half>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu,
                 rsigma, gamma, rows, cols, grid, stream);
  } else if (dtype == torch::kFloat32) {
    launch<float>(dx, dgamma, dbeta, dgamma_part, dbeta_part, dw, x, mu,
                  rsigma, gamma, rows, cols, grid, stream);
  } else {
    assert(false && "Not implemented");
  }
}
\ No newline at end of file
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu
View file @
db92ee13
#include "utils.cuh"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_kernel_traits.h"
#include "ATen/cuda/CUDAContext.h"
#include "ln_fwd_kernels.cuh"
template
<
typename
Ktraits
>
using
namespace
layer_norm
;
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_fwd_kernel
(
void
*
__restrict__
y_
,
void
*
__restrict__
mu_
,
void
*
__restrict__
rsigma_
,
template
<
const
void
*
__restrict__
x_
,
const
void
*
__restrict__
gamma_
,
typename
weight_t
,
const
void
*
__restrict__
beta_
,
const
float
epsilon
,
int
rows
)
{
typename
input_t
,
typename
output_t
,
using
Vec
=
typename
Ktraits
::
Vec
;
typename
compute_t
,
typename
index_t
,
using
base_t
=
typename
Ktraits
::
base_t
;
int
HIDDEN_SIZE
,
using
compute_t
=
typename
Ktraits
::
compute_t
;
int
CTAS_PER_ROW
,
enum
{
NUM_ELTS
=
Vec
::
NUM_ELTS
};
int
WARPS_M
,
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
int
WARPS_N
,
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
int
BYTES_PER_LDG
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
>
void
launch_
(
LaunchParams
<
FwdParams
>
&
launch_params
,
const
bool
configure_params
){
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
enum
{
BYTES_PER_LDG
=
Ktraits
::
BYTES_PER_LDG
};
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
static_assert
(
BYTES_PER_LDG
==
16
,
""
);
input_t
,
output_t
,
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
compute_t
,
enum
{
LDGS
=
BYTES_PER_ROW
/
Ktraits
::
BYTES_PER_ROW_PER_CTA
};
index_t
,
static_assert
(
LDGS
*
Ktraits
::
BYTES_PER_ROW_PER_CTA
==
BYTES_PER_ROW
,
""
);
HIDDEN_SIZE
,
CTAS_PER_ROW
,
const
int
tidx
=
threadIdx
.
x
;
WARPS_M
,
const
int
bidx
=
blockIdx
.
x
;
WARPS_N
,
const
int
lane
=
tidx
%
THREADS_PER_WARP
;
BYTES_PER_LDG
const
int
warp
=
tidx
/
THREADS_PER_WARP
;
>
;
const
int
warp_n
=
warp
%
WARPS_N
;
auto
kernel
=
&
ln_fwd_kernel
<
Kernel_traits
>
;
const
int
warp_m
=
warp
/
WARPS_N
;
if
(
configure_params
)
{
const
int
c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
int
ctas_per_sm
;
const
int
r
=
bidx
*
ROWS_PER_CTA
+
warp_m
;
cudaError
status_
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
);
const
char
*
x_ptr
=
static_cast
<
const
char
*>
(
x_
);
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
launch_params
.
barrier_size
=
0
;
const
char
*
g_ptr
=
static_cast
<
const
char
*>
(
gamma_
);
launch_params
.
workspace_bytes
=
0
;
const
char
*
b_ptr
=
static_cast
<
const
char
*>
(
beta_
);
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
char
*
y_ptr
=
static_cast
<
char
*>
(
y_
);
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
compute_t
*
mu_ptr
=
static_cast
<
compute_t
*>
(
mu_
);
*
Kernel_traits
::
WARPS_M
compute_t
*
rs_ptr
=
static_cast
<
compute_t
*>
(
rsigma_
);
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
Stats
::
stats_t
)
Vec
gamma
[
LDGS
];
*
2
;
Vec
beta
[
LDGS
];
}
#pragma unroll
return
;
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
}
gamma
[
it
].
load_from
(
g_ptr
+
col
*
BYTES_PER_LDG
);
beta
[
it
].
load_from
(
b_ptr
+
col
*
BYTES_PER_LDG
);
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
col
+=
THREADS_PER_ROW
;
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
}
}
auto
stream
=
launch_params
.
stream
;
constexpr
compute_t
rn
=
1.
f
/
compute_t
(
Ktraits
::
COLS
);
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
#pragma unroll
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
>>>
(
launch_params
.
params
);
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
x
[
it
].
load_from
(
x_ptr
+
row
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
);
col
+=
THREADS_PER_ROW
;
}
compute_t
xf
[
LDGS
*
NUM_ELTS
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
xf
[
it
*
NUM_ELTS
+
jt
]
=
compute_t
(
x
[
it
].
data
.
elt
[
jt
]);
}
}
compute_t
mu_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
mu_local
+=
xf
[
it
*
NUM_ELTS
+
jt
];
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mu_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mu_local
,
it
);
}
mu_local
*=
rn
;
if
(
lane
==
0
){
mu_ptr
[
row
]
=
mu_local
;
}
compute_t
var_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
diff
=
xf
[
it
*
NUM_ELTS
+
jt
]
-
mu_local
;
var_local
+=
diff
*
diff
;
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
var_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
var_local
,
it
);
}
compute_t
rsigma
=
rsqrtf
(
var_local
*
rn
+
epsilon
);
if
(
lane
==
0
){
rs_ptr
[
row
]
=
rsigma
;
}
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
base_t
tmp
=
(
rsigma
*
(
xf
[
it
*
NUM_ELTS
+
jt
]
-
mu_local
));
x
[
it
].
data
.
elt
[
jt
]
=
gamma
[
it
].
data
.
elt
[
jt
]
*
tmp
+
beta
[
it
].
data
.
elt
[
jt
];
}
}
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
x
[
it
].
store_to
(
y_ptr
+
row
*
BYTES_PER_ROW
+
col
*
BYTES_PER_LDG
);
col
+=
THREADS_PER_ROW
;
}
}
}
template
<
typename
scalar_t
>
void
launch
(
at
::
Tensor
&
y
,
// BxSxhidden_size
at
::
Tensor
&
mu
,
at
::
Tensor
&
rsigma
,
const
at
::
Tensor
&
x
,
// BxSxhidden_size
const
at
::
Tensor
&
gamma
,
const
at
::
Tensor
&
beta
,
const
float
epsilon
,
const
int
rows
,
const
int
cols
,
const
int
max_gridx
,
cudaStream_t
stream
){
if
(
cols
==
1024
)
{
using
Ktraits
=
Kernel_traits
<
scalar_t
,
1024
,
4
,
1
>
;
const
int
grid
=
std
::
min
<
int
>
(
DIVUP
(
rows
,
Ktraits
::
ROWS_PER_CTA
),
max_gridx
);
ln_fwd_kernel
<
Ktraits
><<<
grid
,
Ktraits
::
THREADS_PER_CTA
,
0
,
stream
>>>
(
y
.
data_ptr
(),
mu
.
data_ptr
(),
rsigma
.
data_ptr
(),
x
.
data_ptr
(),
gamma
.
data_ptr
(),
beta
.
data_ptr
(),
epsilon
,
rows
);
}
else
{
}
else
{
assert
(
false
&&
"Not implemented"
);
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
);
}
}
AT_CUDA_CHECK
(
cudaPeekAtLastError
());
}
}
void
ln_fwd_cuda
(
REGISTER_FWD_LAUNCHER
(
16384
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
at
::
Tensor
&
y
,
// BxSxhidden_size
REGISTER_FWD_LAUNCHER
(
16384
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
at
::
Tensor
&
mu
,
REGISTER_FWD_LAUNCHER
(
16384
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
at
::
Tensor
&
rsigma
,
REGISTER_FWD_LAUNCHER
(
16384
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
const
at
::
Tensor
&
x
,
// BxSxhidden_size
REGISTER_FWD_LAUNCHER
(
16384
,
bf16
,
fp32
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
const
at
::
Tensor
&
gamma
,
const
at
::
Tensor
&
beta
,
REGISTER_FWD_LAUNCHER
(
18432
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
const
float
epsilon
,
REGISTER_FWD_LAUNCHER
(
18432
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
const
int
rows
,
const
int
cols
,
REGISTER_FWD_LAUNCHER
(
18432
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
cudaStream_t
stream
REGISTER_FWD_LAUNCHER
(
18432
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
){
REGISTER_FWD_LAUNCHER
(
18432
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
const
auto
dtype
=
x
.
scalar_type
();
REGISTER_FWD_LAUNCHER
(
20480
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
const
auto
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
REGISTER_FWD_LAUNCHER
(
20480
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
const
int
max_gridx
=
props
->
maxGridSize
[
0
];
REGISTER_FWD_LAUNCHER
(
20480
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
20480
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
//TODO
REGISTER_FWD_LAUNCHER
(
20480
,
bf16
,
fp32
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
// - Using dispatch macro costs 1% perf wtf?!?!
// - Tune FP32 warps
REGISTER_FWD_LAUNCHER
(
24576
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
// - Add more sizes
REGISTER_FWD_LAUNCHER
(
24576
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
if
(
dtype
==
torch
::
kFloat16
)
{
REGISTER_FWD_LAUNCHER
(
24576
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
launch
<
half
>
(
y
,
mu
,
rsigma
,
x
,
gamma
,
beta
,
epsilon
,
rows
,
cols
,
max_gridx
,
stream
);
REGISTER_FWD_LAUNCHER
(
24576
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
}
else
if
(
dtype
==
torch
::
kFloat32
)
{
REGISTER_FWD_LAUNCHER
(
24576
,
bf16
,
fp32
,
bf16
,
fp32
,
2
,
1
,
4
,
16
);
launch
<
float
>
(
y
,
mu
,
rsigma
,
x
,
gamma
,
beta
,
epsilon
,
rows
,
cols
,
max_gridx
,
stream
);
}
else
{
REGISTER_FWD_LAUNCHER
(
25600
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
4
);
assert
(
false
&&
"Not implemented"
);
REGISTER_FWD_LAUNCHER
(
25600
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
8
);
}
REGISTER_FWD_LAUNCHER
(
25600
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
25600
,
bf16
,
bf16
,
bf16
,
fp32
,
2
,
1
,
4
,
8
);
REGISTER_FWD_LAUNCHER
(
25600
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
32768
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
bf16
,
bf16
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
bf16
,
fp32
,
bf16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
bf16
,
bf16
,
bf16
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
bf16
,
fp32
,
bf16
,
fp32
,
8
,
1
,
4
,
16
);
}
apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh
0 → 100644
View file @
db92ee13
#pragma once
#include "ln.h"
namespace
layer_norm
{
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) {

    // Layer-norm forward kernel:
    //   z = gamma * (x - mean(x)) * rsqrt(m2 / COLS + epsilon) + beta
    // computed per row. The per-row mean (mu) and reciprocal sigma (rs) are
    // also written to params.mu / params.rs (presumably for reuse by the
    // backward pass — confirm against ln_bwd). A single row may be spread
    // over CTAS_PER_ROW CTAs; cross-CTA statistics are handled inside Stats.

    // Compile-time geometry, all taken from the kernel traits.
    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };          // vectorized loads per thread per row
    enum { NUM_ELTS = Ktraits::NUM_ELTS };  // elements per vectorized load
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using output_t = typename Ktraits::output_t;
    using index_t = typename Ktraits::index_t;
    using compute_t = typename Ktraits::compute_t;
    using Ivec = typename Ktraits::Ivec;    // input vector type
    using Ovec = typename Ktraits::Ovec;    // output vector type
    using Wvec = typename Ktraits::Wvec;    // weight (gamma/beta) vector type
    using Cvec = typename Ktraits::Cvec;
    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    // Dynamic shared memory, consumed by Stats (size = Ktraits::SMEM_BYTES_FWD).
    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    // Decompose the 1D grid: bidn = CTA id within a row group, bidm = row group.
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    // Warps are laid out as a WARPS_M x WARPS_N grid over rows x columns.
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    // First row handled by this thread, and its (vectorized) column index.
    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

    // gamma/beta are per-column and row-invariant: load them once up front.
    Wvec gamma[LDGS];
    Wvec beta[LDGS];
    index_t idx = c;
#pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        gamma[it].load_from(params.gamma, idx);
        beta[it].load_from(params.beta, idx);
        idx += VEC_COLS_PER_LDG;
    }

    // Reciprocal of the number of columns, used to turn sums into means.
    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);

    // Stride over rows by the total number of rows processed per iteration.
    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        Ivec x[LDGS];
        index_t idx = row * Ktraits::VEC_COLS + c;
        // Input converted to compute precision, kept in registers.
        compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            x[it].load_from(params.x, idx);
#pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                compute_t x_ij = compute_t(x[it].data.elt[jt]);
                xf[it * NUM_ELTS + jt] = x_ij;
            }
            idx += VEC_COLS_PER_LDG;
        }

        // Row statistics: s.x = mean, s.y = sum of squared deviations (m2).
        stats_t s = stats.compute(xf, rn);

        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

        // Exactly one thread per row writes the saved statistics.
        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(rn * m2 + params.epsilon);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;
        }

        // Normalize, scale and shift, then store vectorized.
        Ovec z[LDGS];
        idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
#pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
                output_t g_ij = gamma[it].data.elt[jt];
                output_t b_ij = beta[it].data.elt[jt];
                z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
            }
            z[it].store_to(params.z, idx);
            idx += VEC_COLS_PER_LDG;
        }
    }
}
}
// namespace layer_norm
apex/contrib/csrc/layer_norm/ln_kernel_traits.h
View file @
db92ee13
#pragma once
#pragma once
constexpr
uint32_t
THREADS_PER_WARP
=
32
;
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
dtype
,
int
COLS_
,
int
WARPS_M_
,
int
WARPS_N_
,
namespace
layer_norm
{
int
BYTES_PER_LDG_
=
16
>
template
<
struct
Kernel_traits
{
uint32_t
HIDDEN_SIZE_
,
enum
{
WARPS_M
=
WARPS_M_
};
typename
weight_t_
,
enum
{
WARPS_N
=
WARPS_N_
};
typename
input_t_
,
enum
{
COLS
=
COLS_
};
typename
output_t_
,
typename
compute_t_
,
typename
index_t_
,
uint32_t
THREADS_PER_CTA_
>
struct
Kernel_traits_base
{
using
weight_t
=
weight_t_
;
using
input_t
=
input_t_
;
using
output_t
=
output_t_
;
using
compute_t
=
compute_t_
;
using
index_t
=
index_t_
;
enum
{
HIDDEN_SIZE
=
HIDDEN_SIZE_
};
enum
{
THREADS_PER_CTA
=
THREADS_PER_CTA_
};
enum
{
THREADS_PER_WARP
=
32
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
uint32_t
HIDDEN_SIZE_
,
typename
weight_t_
,
typename
input_t_
,
typename
output_t_
,
typename
compute_t_
,
typename
index_t_
,
uint32_t
THREADS_PER_CTA_
,
uint32_t
BYTES_PER_LDG_
,
typename
Base
=
Kernel_traits_base
<
HIDDEN_SIZE_
,
weight_t_
,
input_t_
,
output_t_
,
compute_t_
,
index_t_
,
THREADS_PER_CTA_
>
>
struct
Kernel_traits_finalize
:
public
Base
{
enum
{
ROWS_PER_CTA
=
Base
::
THREADS_PER_CTA
/
Base
::
THREADS_PER_WARP
};
static_assert
((
int
)
ROWS_PER_CTA
<=
(
int
)
Base
::
THREADS_PER_WARP
);
// Bytes per global load from the input.
enum
{
BYTES_PER_LDG
=
BYTES_PER_LDG_
};
enum
{
BYTES_PER_LDG
=
BYTES_PER_LDG_
};
// Number of elements fetched by a global load.
enum
{
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
compute_t_
)
};
// Bytes per global store of the weights.
enum
{
BYTES_PER_STG
=
ELTS_PER_LDG
*
sizeof
(
weight_t_
)
};
static_assert
(
sizeof
(
BYTES_PER_LDG
)
==
4
,
"Conflict-free smem transpose only implemented for 4B compute type!"
);
static_assert
(
Base
::
THREADS_PER_CTA
==
ROWS_PER_CTA
*
Base
::
THREADS_PER_WARP
,
"We assume one warp per row!"
);
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum
{
COLS
=
HIDDEN_SIZE_
*
sizeof
(
compute_t_
)
/
BYTES_PER_LDG
};
static_assert
(
COLS
*
BYTES_PER_LDG
==
HIDDEN_SIZE_
*
sizeof
(
compute_t_
));
// Shared memory size to transpose the CTA result.
enum
{
SMEM_BYTES_TRANSPOSE
=
Base
::
THREADS_PER_CTA
*
BYTES_PER_LDG
};
// Shared memory size to coalsece the CTA result.
enum
{
SMEM_BYTES_OUTPUT
=
Base
::
THREADS_PER_WARP
*
BYTES_PER_LDG
};
// Shared memory requirement per CTA.
enum
{
SMEM_BYTES_PER_CTA
=
2
*
SMEM_BYTES_TRANSPOSE
+
2
*
SMEM_BYTES_OUTPUT
};
// The type of the reducer.
using
Reducer
=
layer_norm
::
Reducer
<
compute_t_
,
1
,
1
,
1
>
;
// Condition for the whole CTA to participate in syncthreads.
static_assert
(
COLS
%
Base
::
THREADS_PER_WARP
==
0
);
enum
{
CTAS
=
COLS
/
Base
::
THREADS_PER_WARP
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using
Vec
=
Vec
<
dtype
,
BYTES_PER_LDG
>
;
using
vec_t
=
typename
Vec
::
vec_t
;
template
<
using
base_t
=
typename
Vec
::
base_t
;
typename
weight_t_
,
using
packed_t
=
typename
Vec
::
packed_t
;
typename
input_t_
,
using
compute_t
=
typename
Vec
::
compute_t
;
typename
output_t_
,
using
packed_compute_t
=
typename
Vec
::
packed_compute_t
;
typename
compute_t_
,
typename
index_t_
,
uint32_t
HIDDEN_SIZE_
,
uint32_t
CTAS_PER_ROW_
,
uint32_t
WARPS_M_
,
uint32_t
WARPS_N_
,
uint32_t
BYTES_PER_LDG_
=
16
,
typename
Base
=
Kernel_traits_base
<
HIDDEN_SIZE_
,
weight_t_
,
input_t_
,
output_t_
,
compute_t_
,
index_t_
,
WARPS_M_
*
WARPS_N_
*
THREADS_PER_WARP
>
>
struct
Kernel_traits
:
public
Base
{
using
input_t
=
typename
Base
::
input_t
;
using
weight_t
=
typename
Base
::
weight_t
;
using
compute_t
=
typename
Base
::
compute_t
;
using
output_t
=
typename
Base
::
output_t
;
using
index_t
=
typename
Base
::
index_t
;
enum
{
CTAS_PER_ROW
=
CTAS_PER_ROW_
};
enum
{
WARPS_M
=
WARPS_M_
};
enum
{
WARPS_N
=
WARPS_N_
};
enum
{
COLS
=
HIDDEN_SIZE_
};
enum
{
HIDDEN_SIZE
=
HIDDEN_SIZE_
};
enum
{
BYTES_PER_LDG
=
BYTES_PER_LDG_
};
enum
{
NUM_ELTS
=
BYTES_PER_LDG
/
sizeof
(
input_t
)
};
enum
{
THREADS_PER_ROW
=
WARPS_N
*
THREADS_PER_WARP
};
enum
{
THREADS_PER_ROW
=
WARPS_N
*
THREADS_PER_WARP
};
enum
{
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
};
enum
{
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
};
enum
{
ROWS_PER_CTA
=
WARPS_M
};
enum
{
ROWS_PER_CTA
=
WARPS_M
};
enum
{
BYTES_PER_ROW
=
COLS
*
sizeof
(
base
_t
)
};
enum
{
BYTES_PER_ROW
=
COLS
*
sizeof
(
input
_t
)
};
enum
{
BYTES_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
BYTES_PER_LDG
};
enum
{
BYTES_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
BYTES_PER_LDG
};
enum
{
SMEM_BYTES
=
ROWS_PER_CTA
*
COLS
*
sizeof
(
compute_t
)};
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum
{
SMEM_BYTES_WGRAD
=
CTAS_PER_ROW
>
1
?
0
:
ROWS_PER_CTA
*
COLS
*
sizeof
(
compute_t
)
};
static_assert
(
WARPS_M
==
1
||
CTAS_PER_ROW
==
1
);
using
reduce_t
=
typename
layer_norm
::
TypeToVec2
<
compute_t
>::
Type
;
using
Reducer
=
layer_norm
::
Reducer
<
reduce_t
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
>
;
enum
{
SMEM_BYTES_DGRAD
=
Reducer
::
SMEM_BYTES
};
enum
{
SMEM_BYTES
=
SMEM_BYTES_DGRAD
+
SMEM_BYTES_WGRAD
};
using
Ivec
=
layer_norm
::
Vec
<
input_t
,
NUM_ELTS
>
;
using
Ovec
=
layer_norm
::
Vec
<
output_t
,
NUM_ELTS
>
;
using
Wvec
=
layer_norm
::
Vec
<
weight_t
,
NUM_ELTS
>
;
using
Cvec
=
layer_norm
::
Vec
<
compute_t
,
NUM_ELTS
>
;
enum
{
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
input_t
)
};
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert
(
sizeof
(
input_t
)
>=
sizeof
(
output_t
));
static_assert
(
sizeof
(
input_t
)
>=
sizeof
(
weight_t
));
// The number of columns fetched per load from input: one per thread.
enum
{
VEC_COLS_PER_LDG
=
CTAS_PER_ROW
*
THREADS_PER_ROW
};
// The total number of vectorized loads/stores per hidden vector.
enum
{
VEC_COLS
=
COLS
/
ELTS_PER_LDG
};
// The number of loads per thread for the input.
enum
{
LDGS
=
VEC_COLS
/
VEC_COLS_PER_LDG
};
static_assert
(
LDGS
*
VEC_COLS_PER_LDG
==
VEC_COLS
);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using
Stats
=
layer_norm
::
Stats
<
compute_t
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
>
;
enum
{
SMEM_BYTES_FWD
=
Stats
::
SMEM_BYTES
};
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace layer_norm
apex/contrib/csrc/layer_norm/ln_utils.cuh
0 → 100644
View file @
db92ee13
#pragma once
#include <cassert>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "ln.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr
uint32_t
THREADS_PER_WARP
=
32
;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Report a failed CUDA runtime call and terminate the process.
// Invoked through the CHECK_CUDA(...) macro, which supplies __FILE__/__LINE__.
inline void check_cuda_(cudaError_t status, const char *file, int line) {
    if( status == cudaSuccess ) {
        return;  // Fast path: nothing to report.
    }
    fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
    exit(status);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(ans) \
{ check_cuda_((ans), __FILE__, __LINE__); }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Component-wise addition for float2 (device only).
inline __device__ float2 operator+(const float2 &a, const float2 &b) {
    float2 result;
    result.x = a.x + b.x;
    result.y = a.y + b.y;
    return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Component-wise in-place accumulation for float2 (device only).
inline __device__ void operator+=(float2 &a, const float2 &b) {
    a.x = a.x + b.x;
    a.y = a.y + b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary addition functor, used as the reduction operator for Reducer/Stats.
template<typename T>
struct Sum {
    inline __device__ Sum() {}

    inline __device__ T operator()(const T &lhs, const T &rhs) {
        return lhs + rhs;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Butterfly (XOR-pattern) exchange across the full warp.
template<typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t lane_mask) {
    return __shfl_xor_sync(uint32_t(-1), x, lane_mask);
}
// float2 cannot be shuffled as one register: exchange each lane separately.
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t lane_mask) {
    float2 result;
    result.x = warp_shuffle_xor(x.x, lane_mask);
    result.y = warp_shuffle_xor(x.y, lane_mask);
    return result;
}
// Shift values down the warp by `delta` lanes (upper lanes keep their own value).
template<typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t delta) {
    return __shfl_down_sync(uint32_t(-1), x, delta);
}
// float2 cannot be shuffled as one register: shift each lane separately.
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t delta) {
    float2 result;
    result.x = warp_shuffle_down(x.x, delta);
    result.y = warp_shuffle_down(x.y, delta);
    return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace
layer_norm
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// 64-byte aggregate: backing type for 64B vectorized loads/stores (see BytesToType<64>).
struct uint16 {
    uint4 u;
    uint4 v;
    uint4 s;
    uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 32-byte aggregate: backing type for 32B vectorized loads/stores (see BytesToType<32>).
struct uint8 {
    uint4 u;
    uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Map a byte count to a trivially-copyable type of exactly that size.
// Vec<> uses this to pick the machine word for a single wide load/store.
template<int BYTES>
struct BytesToType {};

template<>
struct BytesToType<64> {
    using Type = uint16;
    static_assert(sizeof(Type) == 64);
};

template<>
struct BytesToType<32> {
    using Type = uint8;
    static_assert(sizeof(Type) == 32);
};

template<>
struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<>
struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<>
struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<>
struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<>
struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Map a scalar element type to its CUDA two-element vector counterpart.
template<typename T>
struct TypeToVec2 {};

template<>
struct TypeToVec2<float> {
    using Type = float2;
};

template<>
struct TypeToVec2<half> {
    using Type = half2;
};

template<>
struct TypeToVec2<nv_bfloat16> {
    using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time accessor for the INDEX-th component (x/y/z/w) of a CUDA
// vector type, converted to type R at the call site.
template<int INDEX>
struct Get {
    template<typename T, typename R>
    static inline __device__ R of(const T &vec);
};

template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
    return vec.x;
}

template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
    return vec.y;
}

template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
    return vec.z;
}

template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
    return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Generic element converter: defer to Dst's converting constructor.
// Specialized below for packed float2 -> half2 / bfloat162 conversions.
template<typename Src, typename Dst>
struct Converter {
    static inline __device__ Dst convert(const Src &src) {
        Dst converted(src);
        return converted;
    }
};
// float2 -> half2: single hardware round-to-nearest-even pair conversion.
template<>
struct Converter<float2, half2> {
    static inline __device__ half2 convert(const float2 &x) {
        return __float22half2_rn(x);
    }
};
// float2 -> nv_bfloat162. On sm_80+ this is a single hardware instruction;
// older architectures convert the two lanes individually.
template<>
struct Converter<float2, nv_bfloat162> {
    static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
        return __float22bfloat162_rn(x);
#else
        // BUGFIX: the original fallback used `union { nv_bfloat162 raw;
        // nv_bfloat16 x; nv_bfloat16 y; }` — union members all start at
        // offset 0, so `x` and `y` aliased the same bytes and writing `y`
        // clobbered `x`, corrupting the low half of the result. Write the
        // two lanes of the packed type directly instead.
        nv_bfloat162 out;
        out.x = __float2bfloat16_rn(x.x);
        out.y = __float2bfloat16_rn(x.y);
        return out;
#endif
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Zero-value factory; specialized for vector types without a T(0.f) constructor.
template<typename T>
struct Zeros {
    static inline __device__ T get() {
        return T(0.f);
    }
};

template<>
struct Zeros<float2> {
    static inline __device__ float2 get() {
        return make_float2(0.f, 0.f);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Fixed-size register vector of NUM_ELT elements of Elt_type, moved to/from
// memory as one BYTES-wide machine word (type chosen via BytesToType).
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {

    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    using Vec_type = typename BytesToType<BYTES>::Type;

    // Union view: either the packed word (`vec`) or individual elements (`elt`).
    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    // Element-wise convert into another Vec of the same length.
    template<typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            other.data.elt[it] = S(this->data.elt[it]);
        }
    }

    // Fill each element from op(index).
    template<typename Op>
    inline __device__ void assign(const Op &op) {
#pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            this->data.elt[it] = op(it);
        }
    }

    // Vectorized load; `idx` counts in units of Vec_type (BYTES each),
    // so base_ptr + idx must be BYTES-aligned.
    inline __device__ void load_from(const void *base_ptr, const size_t idx) {
        this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
    }

    // Vectorized store; same indexing/alignment convention as load_from.
    inline __device__ void store_to(void *base_ptr, const size_t idx) {
        static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Software barrier synchronizing the CTAS_PER_ROW CTAs that cooperate on one
// row group. Uses two global-memory counters (b0_/b1_) in ping-pong fashion,
// alternating increment-to-N and decrement-to-0 phases so counters never need
// to be reset between sync() calls.
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {

    template<typename Params>
    inline __device__ InterCTASync(Params &params, uint32_t bidm, uint32_t bidn)
        : phase_counter_(0)
        , b0_(params.barrier + bidm)                        // The barrier for this group of CTAs.
        , b1_(params.barrier + bidm + params.ctas_per_col)  // The barrier for this group of CTAs.
    {
        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
    }

    // Atomically add `step` to the counter (release), then spin with acquire
    // loads until it reaches `expected`.
    inline __device__ void spin_wait_(int *barrier, int step, int expected) {
        asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
        for( int found = -1; found != expected; ) {
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
        }
    }

    inline __device__ void sync() {
        // ALL THREADS MUST ENTER!

        // We switch barrier every iteration.
        int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
        // We decrement every other iteration.
        bool dec = phase_counter_ & 0x2;
        int step = dec ? -1 : 1;
        int expected = dec ? 0 : CTAS_PER_ROW;
        // There are only 4 phases: up/down for b0/b1.
        phase_counter_ = (phase_counter_ + 1) & 0x3;

        // Only one thread per CTA touches the global counter...
        if( threadIdx.x == 0 ) {
            spin_wait_(barrier, step, expected);
        }
        // CTA waits for thread 0
        __syncthreads();
    }

    int phase_counter_;  // Current phase in [0, 4): selects barrier and direction.
    int *b0_;
    int *b1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multi-CTA reducer: reduces within the CTA via the single-CTA base class,
// then combines the per-CTA partials through a global-memory workspace
// guarded by InterCTASync. Workspaces w0_/w1_ ping-pong with the barrier
// phase so consecutive allreduce() calls cannot race.
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
    using Type = typename Base::Type;

    enum { SMEM_BYTES = Base::SMEM_BYTES };

    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn,
                              uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , inter_cta_(params, bidm, bidn)
        , bidn_(bidn)  // CTA id within the group.
        , w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
    {
    }

    // All-reduce `data` across every CTA in the row group; every thread
    // returns the total.
    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        data = Base::reduce(data, op);
        // We switch workspace every iteration.
        T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        // Warp leaders 0 hold the CTA-local results.
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            workspace[bidn_] = data;
        }
        inter_cta_.sync();
        // One warp gathers all CTA partials, so the group cannot exceed a warp.
        static_assert(CTAS_PER_ROW <= 32);
        T total = Zeros<T>::get();
        if( this->lane_ < CTAS_PER_ROW ) {
            total = workspace[this->lane_];
        }
        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

        return total;
    }

    InterCTASync inter_cta_;

    T *w0_;   // Workspace for even barrier phases.
    T *w1_;   // Workspace for odd barrier phases.
    int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Base case: one CTA, one warp per row (WARPS_N == 1). Reductions stay inside
// the warp and need no shared memory or global workspace.
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {

    using Type = T;
    enum { SMEM_BYTES = 0 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn,
                              uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
        : warp_n_(warp_n)
        , lane_(lane)
    {
    }

    // Butterfly all-reduce: every lane ends up with the full warp total.
    template<typename Op>
    static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
            data = op(data, warp_shuffle_xor(data, it));
        }
        return data;
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        return allreduce_(data, op);
    }

    // Tree reduce via shuffle-down.
    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        // only lane 0 holds the result!
#pragma unroll
        for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
            data = op(data, warp_shuffle_down(data, it));
        }
        return data;
    }

    int warp_n_;  // Warp index along the row (always 0 here).
    int lane_;    // Lane index within the warp.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Single-CTA, multi-warp-per-row reducer: warp-level reduction first, then
// the WARPS_N warp partials are combined through shared memory. Two smem
// buffers ping-pong (use0_) so back-to-back calls need only one
// __syncthreads() each.
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {

    using Base = Reducer<T, 1, WARPS_M, 1>;
    using Type = T;

    // Two ping-pong buffers of WARPS_M x WARPS_N partials.
    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn,
                              uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        // Each row (warp_m) gets its own WARPS_N-wide slice of shared memory.
        smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    // All threads of the row return the row total.
    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        T *smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
#pragma unroll
        for( int it = 0; it < WARPS_N; it++ ) {
            out = op(out, smem[it]);
        }
        return out;
    }

    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        T *smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        // only intra-CTA group leader holds the result!
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
            for( int it = 0; it < WARPS_N; it++ ) {
                out = op(out, smem[it]);
            }
        }
        return out;
    }

    T *smem0_;
    T *smem1_;
    bool use0_;  // Selects the active ping-pong buffer.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Combine per-lane (mean, m2, count) statistics across the first `num_active`
// lanes of a warp using Chan et al.'s pairwise update, then broadcast the
// merged mean/m2 to the whole warp.
//   m_a  : in/out mean
//   m2_a : in/out sum of squared deviations
//   n_a  : in/out sample count
template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) {
    //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
    int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
    for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
        // Exchange
        T n_b = warp_shuffle_down(n_a, step);
        T m_b = warp_shuffle_down(m_a, step);
        T m2_b = warp_shuffle_down(m2_a, step);

        // Update
        const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
        const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
        const T delta = m_a - m_b;
        // BUGFIX: these two were declared `const float`, which silently
        // truncated the combined statistics whenever T is a wider type.
        const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
        const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

        n_a = n_ab;
        m_a = m_ab;
        m2_a = m2_ab;
    }
    // Intra-warp broadcast (only lane 0 has valid stats).
    m_a = __shfl_sync(uint32_t(-1), m_a, 0);
    m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multi-CTA statistics: each CTA computes block-local mean/m2 via BlockStats,
// publishes them to a global workspace, and a grid-group barrier (InterCTASync)
// makes the CTAS_PER_ROW partials visible so every CTA can finalize locally.
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
    // stats_t packs (mean, m2) as a 2-vector; see the WARPS_N==1 base case.
    using stats_t = typename BlockStats::stats_t;

    // Shared memory is only needed by the intra-CTA reduction.
    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : inter_cta_(params, bidm, bidn)
        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , bidn_(bidn) // CTA id within the group.
        // Two workspace buffers (w0_/w1_), alternated by the sync phase counter,
        // so a fast CTA cannot overwrite partials a slow CTA is still reading.
        , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
        , warp_n_(warp_n)
        , lane_(lane)
    {
    }

    // Combine the N per-thread elements into (mean, m2) over the whole row,
    // reduced across all CTAS_PER_ROW CTAs in the group.
    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
        // TODO rn is not really needed here..
        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
        stats_t block_stats = block_stats_.compute(elts, block_rn);

        // Pick the workspace buffer for the current barrier phase.
        stats_t * workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        // One thread per CTA publishes the CTA-local partial stats.
        if( warp_n_ == 0 && lane_ == 0 ) {
            workspace[bidn_] = block_stats;
        }

        // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
        inter_cta_.sync();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
        static_assert(CTAS_PER_ROW <= 32);

        // Every warp does the final reduction locally.
        if( lane_ < CTAS_PER_ROW ) {
            stats_t result = workspace[lane_];
            n = ELTS_PER_ROW_PER_CTA;
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        // Chan-style combine across lanes; broadcasts the result warp-wide.
        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

        return { m, m2 };
    }

    InterCTASync inter_cta_;     // grid-group barrier over the CTA row group
    BlockStats block_stats_;     // intra-CTA reducer

    stats_t *w0_;                // workspace buffer, even phases
    stats_t *w1_;                // workspace buffer, odd phases
    int bidn_;                   // CTA index within the row group
    int warp_n_;
    int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Single-CTA specialization: reduce per-warp statistics across the WARPS_N
// warps of a row through double-buffered shared memory, then finalize with a
// single warp using the Chan combining step.
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
    using WarpStats = Stats<T, 1, WARPS_M, 1>;
    using stats_t = typename WarpStats::stats_t;

    // Two buffers of WARPS_M * WARPS_N stats_t each (ping-pong via use0_),
    // so back-to-back compute() calls need no extra barrier between them.
    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        // Each warp_m row gets its own WARPS_N-wide slice of the buffer.
        smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    // Combine the N per-thread elements into (mean, m2) over the CTA row.
    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        // Alternate shared-memory buffers on every call.
        stats_t * smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        // Compute warp local for all WARPS_N
        constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
        stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

        //Each warp warp leader stores its stats
        const auto warp_n = warp_stats_.reducer_.warp_n_;
        const auto lane = warp_stats_.reducer_.lane_;
        if( lane == 0 ) {
            smem[warp_n] = warp_stats;
        }
        // Make all warp partials visible CTA-wide before the final pass.
        __syncthreads();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume that there are less than 32 warps, such that we can finalize with a single warp
        static_assert(WARPS_N <= 32);
        if(lane < WARPS_N){
            stats_t result = smem[lane];
            n = N * THREADS_PER_WARP;  // element count behind each warp partial
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        // Chan-style combine across lanes; broadcasts the result warp-wide.
        warp_chan_upd_dynamic(m, m2, n, WARPS_N);

        return { m, m2 };
    }
    WarpStats warp_stats_;   // per-warp reducer (base case)
    stats_t * smem0_;        // shared buffer, even calls
    stats_t * smem1_;        // shared buffer, odd calls
    bool use0_;              // which buffer the next compute() uses
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Base case: warp-local statistics. Two allreduce passes over the warp —
// first the mean, then the sum of squared deviations from that mean.
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
    // Stats are packed as a 2-vector (mean, m2) of T.
    using stats_t = typename TypeToVec2<T>::Type;
    // The simple Warp reducer.
    using Reducer = Reducer<T, 1, WARPS_M, 1>;

    // Pure warp-shuffle reduction; no shared memory needed.
    enum { SMEM_BYTES = 0 };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
    {
    }

    // Compute (mean, m2) of the N per-thread elements across the warp.
    // rn is the reciprocal of the total element count behind this warp.
    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        auto sum = Sum<T>();

        // Pass 1: warp-wide mean.
        T m = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            m += elts[it];
        }
        m = reducer_.allreduce(m, sum) * rn;

        // Pass 2: warp-wide sum of squared deviations (not yet normalized).
        T m2 = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            T diff = (elts[it] - m);
            m2 += diff * diff;
        }
        m2 = reducer_.allreduce(m2, sum);

        return {m, m2};
    }

    Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace layer_norm
apex/contrib/csrc/layer_norm/utils.cuh
deleted
100644 → 0
View file @
d150afdc
#pragma once
#include "torch/extension.h"
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
// Integer ceiling division: number of y-sized chunks needed to cover x.
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))

// Dispatch the lambda body in __VA_ARGS__ on a Float or Half tensor type,
// raising for any other scalar type. Wrapped in an immediately-invoked
// lambda so the macro can be used as an expression.
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                             \
  [&] {                                                                      \
    const auto &the_type = TYPE;                                             \
    /* don't use TYPE again in case it is an expensive or side-effect op */  \
    at::ScalarType _st = ::detail::scalar_type(the_type);                    \
    switch (_st) {                                                           \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
      default:                                                               \
        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
    }                                                                        \
  }()
// Maps a raw byte width to the unsigned CUDA vector type used for vectorized
// loads/stores of that width, plus a zero value of that type. Only the
// specialized widths (16/8/4/2 bytes) are usable; the primary template is
// intentionally empty so an unsupported width fails to compile.
template<int Bytes>
struct Vec_type {};

// 16 bytes: one 128-bit access.
template<>
struct Vec_type<16> {
    using Type = uint4;
    static __device__ inline Type zero() { return make_uint4(0, 0, 0, 0); }
};

// 8 bytes: one 64-bit access.
template<>
struct Vec_type<8> {
    using Type = uint2;
    static __device__ inline Type zero() { return make_uint2(0, 0); }
};

// 4 bytes: one 32-bit access.
template<>
struct Vec_type<4> {
    using Type = uint32_t;
    static __device__ inline Type zero() { return 0; }
};

// 2 bytes: one 16-bit access.
template<>
struct Vec_type<2> {
    using Type = uint16_t;
    static __device__ inline Type zero() { return 0; }
};
// Type traits for a storage dtype: the scalar (base_t), its 2-wide packed
// form (packed_t), and the wider types used for arithmetic (compute_t /
// packed_compute_t). The generic case computes in float with no packing.
template<typename T>
struct TypeInfo {
    using base_t = T;
    using packed_t = T;
    using compute_t = float;
    using packed_compute_t = float;
};

// half is packed as half2 and promoted to float/float2 for arithmetic.
template<>
struct TypeInfo<half> {
    using base_t = half;
    using packed_t = half2;
    using compute_t = float;
    using packed_compute_t = float2;
};
// Fixed-width register vector: Bytes of dtype elements moved to/from memory
// as a single vectorized access (via Vec_type<Bytes>), viewable either as
// individual elements or as packed pairs through the union.
// NOTE(review): load_from/store_to reinterpret `ptr` as vec_t — callers are
// presumably responsible for Bytes-aligned pointers; confirm at call sites.
template<typename dtype, int Bytes>
struct Vec {

    using base_t = typename TypeInfo<dtype>::base_t;
    using packed_t = typename TypeInfo<dtype>::packed_t;
    using compute_t = typename TypeInfo<dtype>::compute_t;
    using packed_compute_t = typename TypeInfo<dtype>::packed_compute_t;

    // Bytes must hold a whole number of elements and of packed pairs.
    static_assert(Bytes % sizeof(base_t) == 0, "");
    static_assert(Bytes % sizeof(packed_t) == 0, "");

    enum { BYTES_PER_THREAD = Bytes };
    enum { NUM_ELTS = Bytes / sizeof(base_t) };
    enum { NUM_PACKED = Bytes / sizeof(packed_t) };

    // Raw transfer type (uint4/uint2/...) chosen by width.
    using vec_t = typename Vec_type<Bytes>::Type;
    // Union gives three views of the same registers: raw vector, elements,
    // packed pairs.
    using store_t = union {
        vec_t raw;
        base_t elt[NUM_ELTS];
        packed_t packed[NUM_PACKED];
    };

    store_t data;

    // Zero-initialize so partially-filled vectors are well-defined.
    __device__ Vec() {
        data.raw = Vec_type<Bytes>::zero();
    }

    // Unconditional vectorized load from ptr.
    __device__ inline void load_from(const char *ptr) {
        data.raw = *reinterpret_cast<const vec_t *>(ptr);
    }

    // Load when is_valid, otherwise fill with zeros (for row/col tails).
    __device__ inline void load_or_zero(const char *ptr, const bool is_valid) {
        data.raw = is_valid ? *reinterpret_cast<const vec_t *>(ptr) : Vec_type<Bytes>::zero();
    }

    // Unconditional vectorized store to ptr.
    __device__ inline void store_to(char *ptr) const {
        *reinterpret_cast<vec_t *>(ptr) = data.raw;
    }

    // Store only when is_valid (predicated tail handling).
    __device__ inline void store_valid(char *ptr, const bool is_valid) const {
        if( is_valid ) *reinterpret_cast<vec_t *>(ptr) = data.raw;
    }
};
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp
View file @
db92ee13
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
#include <vector>
namespace
multihead_attn
{
namespace
multihead_attn
{
namespace
fused_softmax
{
namespace
fused_softmax
{
namespace
additive_mask_softmax_dropout
{
namespace
additive_mask_softmax_dropout
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
bool
is_training
,
torch
::
Tensor
const
&
input
,
int
heads
,
const
half
*
pad_mask
,
float
dropout_prob
);
torch
::
Tensor
const
&
input
,
const
half
*
pad_mask
,
float
dropout_prob
);
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
int
heads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
);
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
);
// C++ interface
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std
::
vector
<
torch
::
Tensor
>
fwd
(
std
::
vector
<
torch
::
Tensor
>
fwd
(
bool
use_mask
,
bool
is_training
,
int
heads
,
bool
use_mask
,
torch
::
Tensor
const
&
input
,
bool
is_training
,
torch
::
Tensor
const
&
pad_mask
,
int
heads
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
pad_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
if
(
use_mask
)
{
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only BYTE is supported"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only BYTE is supported"
);
}
}
return
fwd_cuda
(
return
fwd_cuda
(
is_training
,
heads
,
input
,
is_training
,
use_mask
?
static_cast
<
const
half
*>
(
pad_mask
.
data_ptr
())
heads
,
:
nullptr
,
input
,
dropout_prob
);
use_mask
?
static_cast
<
const
half
*>
(
pad_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
}
torch
::
Tensor
bwd
(
torch
::
Tensor
bwd
(
bool
use_mask
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
bool
use_mask
,
torch
::
Tensor
const
&
softmax_results
,
int
heads
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
return
bwd_cuda
(
heads
,
output_grads
,
softmax_results
,
dropout_mask
,
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
dropout_prob
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return
bwd_cuda
(
heads
,
output_grads
,
softmax_results
,
dropout_mask
,
dropout_prob
);
}
}
}
//
end
namespace mask_softmax_dropout
}
// namespace
additive_
mask_softmax_dropout
}
// end namespace fused_softmax
}
// end namespace fused_softmax
}
// end namespace multihead_attn
}
// end namespace multihead_attn
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
additive_mask_softmax_dropout
::
fwd
,
"Self Multihead Attention masked softmax dropout -- Forward."
);
m
.
def
(
"forward"
,
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
additive_mask_softmax_dropout
::
bwd
,
"Self Multihead Attention masked softmax dropout -- Backward."
);
&
multihead_attn
::
fused_softmax
::
additive_mask_softmax_dropout
::
fwd
,
"Self Multihead Attention masked softmax dropout -- Forward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
additive_mask_softmax_dropout
::
bwd
,
"Self Multihead Attention masked softmax dropout -- Backward."
);
}
}
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
View file @
db92ee13
#include <vector>
#include <math.h>
#include <iostream>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include "softmax.h"
#include "dropout.h"
#include "dropout.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs
// symbol to be automatically resolved by PyTorch libs
extern
THCState
*
state
;
namespace
multihead_attn
{
namespace
multihead_attn
{
namespace
fused_softmax
{
namespace
fused_softmax
{
namespace
additive_mask_softmax_dropout
{
namespace
additive_mask_softmax_dropout
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
bool
is_training
,
torch
::
Tensor
const
&
input
,
int
heads
,
const
half
*
pad_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
input
,
const
half
*
pad_mask
,
float
dropout_prob
)
{
const
int
attn_batches
=
input
.
size
(
0
);
const
int
attn_batches
=
input
.
size
(
0
);
const
int
sequences
=
attn_batches
/
heads
;
const
int
sequences
=
attn_batches
/
heads
;
const
int
q_seq_len
=
input
.
size
(
1
);
const
int
q_seq_len
=
input
.
size
(
1
);
...
@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -41,63 +35,54 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
input
.
options
().
requires_grad
(
false
);
auto
act_options
=
input
.
options
().
requires_grad
(
false
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
softmax_results
=
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
input_ptr
=
static_cast
<
void
*>
(
input
.
data_ptr
());
void
*
input_ptr
=
static_cast
<
void
*>
(
input
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
// Padded Softmax
// Padded Softmax
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
if
(
pad_mask
==
nullptr
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
k_seq_len
,
attn_batches
*
q_seq_len
);
}
else
{
}
else
{
softmax_success
=
dispatch_additive_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_additive_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
pad_mask
,
k_seq_len
,
pad_mask
,
k_seq_len
,
attn_batches
*
q_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
}
}
if
(
is_training
)
{
if
(
is_training
)
{
//use at:: function so that C++ version generates the same random mask as python version
// use at:: function so that C++ version generates the same random mask as
auto
dropout_tuple
=
at
::
_fused_dropout
(
softmax_results
,
1.0
f
-
dropout_prob
);
// python version
auto
dropout_tuple
=
at
::
_fused_dropout
(
softmax_results
,
1.0
f
-
dropout_prob
);
dropout_results
=
std
::
get
<
0
>
(
dropout_tuple
);
dropout_results
=
std
::
get
<
0
>
(
dropout_tuple
);
dropout_mask
=
std
::
get
<
1
>
(
dropout_tuple
);
dropout_mask
=
std
::
get
<
1
>
(
dropout_tuple
);
}
}
// Matmul2
// Matmul2
return
{
return
{
dropout_results
,
dropout_mask
,
softmax_results
};
dropout_results
,
dropout_mask
,
softmax_results
};
}
}
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
int
heads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
const
int
attn_batches
=
output_grads
.
size
(
0
);
const
int
attn_batches
=
output_grads
.
size
(
0
);
const
int
q_seq_len
=
output_grads
.
size
(
1
);
const
int
q_seq_len
=
output_grads
.
size
(
1
);
const
int
k_seq_len
=
q_seq_len
;
const
int
k_seq_len
=
q_seq_len
;
...
@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
...
@@ -109,23 +94,20 @@ torch::Tensor bwd_cuda(
cublasSetStream
(
handle
,
stream
);
cublasSetStream
(
handle
,
stream
);
// Output Tensor Allocations
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
1.0
/
(
1.0
-
dropout_prob
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
stream
);
k_seq_len
,
// backward pass is completely in-place
attn_batches
*
q_seq_len
,
stream
);
//backward pass is completely in-place
return
output_grads
;
return
output_grads
;
}
}
}
}
// namespace additive_mask_softmax_dropout
}
}
// namespace fused_softmax
}
}
// namespace multihead_attn
apex/contrib/csrc/multihead_attn/dropout.h
View file @
db92ee13
...
@@ -11,33 +11,22 @@
...
@@ -11,33 +11,22 @@
const
int
UNROLL
=
4
;
const
int
UNROLL
=
4
;
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
__global__
void
typename
accscalar_t
,
apex_fused_dropout_kernel
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
typename
IndexType
uint8_t
*
mask
,
IndexType
totalElements
,
accscalar_t
p
,
>
std
::
pair
<
uint64_t
,
uint64_t
>
seeds
)
{
__global__
void
apex_fused_dropout_kernel
(
scalar_t
const
*
inputs
,
accscalar_t
pinv
=
accscalar_t
(
1
)
/
p
;
scalar_t
*
outputs
,
uint8_t
*
mask
,
IndexType
totalElements
,
accscalar_t
p
,
std
::
pair
<
uint64_t
,
uint64_t
>
seeds
)
{
accscalar_t
pinv
=
accscalar_t
(
1
)
/
p
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandStatePhilox4_32_10_t
state
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
curand_init
(
seeds
.
first
,
idx
,
seeds
.
second
,
&
state
);
seeds
.
first
,
idx
,
seeds
.
second
,
&
state
);
IndexType
rounded_size
=
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
IndexType
rounded_size
=
for
(
IndexType
linearIndex
=
idx
;
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
linearIndex
<
rounded_size
;
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
for
(
IndexType
linearIndex
=
idx
;
linearIndex
<
rounded_size
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float4
rand
=
curand_uniform4
(
&
state
);
scalar_t
src
[
UNROLL
];
scalar_t
src
[
UNROLL
];
rand
.
x
=
rand
.
x
<=
p
;
rand
.
x
=
rand
.
x
<=
p
;
...
@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
...
@@ -54,7 +43,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
if
(
li
<
totalElements
)
{
outputs
[
li
]
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
outputs
[
li
]
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
}
}
}
}
...
@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
...
@@ -62,34 +51,23 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
}
}
}
}
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
__global__
void
apex_dropout_add_kernel
(
scalar_t
const
*
inputs
,
__global__
void
apex_dropout_add_kernel
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
scalar_t
const
*
add_inputs
,
scalar_t
*
outputs
,
scalar_t
*
outputs
,
uint8_t
*
mask
,
uint8_t
*
mask
,
IndexType
totalElements
,
accscalar_t
p
,
IndexType
totalElements
,
std
::
pair
<
uint64_t
,
uint64_t
>
seeds
)
{
accscalar_t
p
,
accscalar_t
pinv
=
accscalar_t
(
1
)
/
p
;
std
::
pair
<
uint64_t
,
uint64_t
>
seeds
)
{
accscalar_t
pinv
=
accscalar_t
(
1
)
/
p
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandStatePhilox4_32_10_t
state
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
curand_init
(
seeds
.
first
,
idx
,
seeds
.
second
,
&
state
);
seeds
.
first
,
idx
,
seeds
.
second
,
&
state
);
IndexType
rounded_size
=
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
IndexType
rounded_size
=
for
(
IndexType
linearIndex
=
idx
;
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
linearIndex
<
rounded_size
;
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
for
(
IndexType
linearIndex
=
idx
;
linearIndex
<
rounded_size
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float4
rand
=
curand_uniform4
(
&
state
);
scalar_t
src
[
UNROLL
];
scalar_t
src
[
UNROLL
];
scalar_t
add_src
[
UNROLL
];
scalar_t
add_src
[
UNROLL
];
...
@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
...
@@ -108,7 +86,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
if
(
li
<
totalElements
)
{
accscalar_t
int1
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
accscalar_t
int1
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
outputs
[
li
]
=
static_cast
<
scalar_t
>
(
static_cast
<
accscalar_t
>
(
add_src
[
ii
])
+
int1
);
outputs
[
li
]
=
static_cast
<
scalar_t
>
(
static_cast
<
accscalar_t
>
(
add_src
[
ii
])
+
int1
);
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
}
}
}
}
...
@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
...
@@ -116,22 +95,16 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
}
}
}
}
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
__global__
void
apex_add_kernel
(
scalar_t
const
*
inputs
,
typename
accscalar_t
,
scalar_t
const
*
add_inputs
,
scalar_t
*
outputs
,
typename
IndexType
IndexType
totalElements
)
{
>
__global__
void
apex_add_kernel
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
scalar_t
*
outputs
,
IndexType
totalElements
)
{
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
rounded_size
=
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
IndexType
rounded_size
=
for
(
IndexType
linearIndex
=
idx
;
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
linearIndex
<
rounded_size
;
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
for
(
IndexType
linearIndex
=
idx
;
linearIndex
<
rounded_size
;
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
scalar_t
src
[
UNROLL
];
scalar_t
src
[
UNROLL
];
scalar_t
add_src
[
UNROLL
];
scalar_t
add_src
[
UNROLL
];
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
...
@@ -151,23 +124,17 @@ __global__ void apex_add_kernel( scalar_t const *inputs,
...
@@ -151,23 +124,17 @@ __global__ void apex_add_kernel( scalar_t const *inputs,
}
}
}
}
template
<
typename
scalar_t
,
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
accscalar_t
,
typename
IndexType
>
__global__
void
apex_masked_scale_kernel
(
scalar_t
const
*
inputs
,
__global__
void
apex_masked_scale_kernel
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
scalar_t
*
outputs
,
uint8_t
const
*
mask
,
uint8_t
const
*
mask
,
IndexType
totalElements
,
IndexType
totalElements
,
accscalar_t
scale
accscalar_t
scale
)
{
)
{
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
IndexType
rounded_size
=
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
IndexType
rounded_size
=
for
(
IndexType
linearIndex
=
idx
;
((
totalElements
-
1
)
/
(
blockDim
.
x
*
gridDim
.
x
*
UNROLL
)
+
1
)
*
linearIndex
<
rounded_size
;
blockDim
.
x
*
gridDim
.
x
*
UNROLL
;
linearIndex
+
=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
for
(
IndexType
linearIndex
=
idx
;
linearIndex
<
rounded_size
;
{
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
scalar_t
src
[
UNROLL
];
scalar_t
src
[
UNROLL
];
scalar_t
msk
[
UNROLL
];
scalar_t
msk
[
UNROLL
];
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
...
@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
...
@@ -180,33 +147,34 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
if
(
li
<
totalElements
)
{
outputs
[
li
]
=
static_cast
<
accscalar_t
>
(
src
[
ii
])
*
scale
*
static_cast
<
accscalar_t
>
(
msk
[
ii
]);
outputs
[
li
]
=
static_cast
<
accscalar_t
>
(
src
[
ii
])
*
scale
*
static_cast
<
accscalar_t
>
(
msk
[
ii
]);
}
}
}
}
}
}
}
}
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
void
apex_fused_dropout_cuda
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
typename
accscalar_t
,
uint8_t
*
mask
,
IndexType
totalElements
,
typename
IndexType
accscalar_t
p
)
{
>
void
apex_fused_dropout_cuda
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
uint8_t
*
mask
,
IndexType
totalElements
,
accscalar_t
p
)
{
auto
gen
=
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
();
auto
gen
=
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
();
int
block_size
=
256
;
int
block_size
=
256
;
dim3
dim_block
(
block_size
);
dim3
dim_block
(
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
unsigned
int
blocks_per_sm
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
unsigned
int
blocks_per_sm
=
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
//number of times random will be generated per thread, to offset philox counter in the random state
// number of times random will be generated per thread, to offset philox
int64_t
counter_offset
=
((
totalElements
-
1
)
/
(
block_size
*
grid
.
x
*
UNROLL
)
+
1
)
*
UNROLL
;
// counter in the random state
int64_t
counter_offset
=
((
totalElements
-
1
)
/
(
block_size
*
grid
.
x
*
UNROLL
)
+
1
)
*
UNROLL
;
std
::
pair
<
uint64_t
,
uint64_t
>
rng_engine_inputs
;
std
::
pair
<
uint64_t
,
uint64_t
>
rng_engine_inputs
;
{
{
// See Note [Acquire lock when using random generators]
// See Note [Acquire lock when using random generators]
...
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
...
@@ -215,36 +183,39 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
rng_engine_inputs
=
gen
->
philox_engine_inputs
(
counter_offset
);
rng_engine_inputs
=
gen
->
philox_engine_inputs
(
counter_offset
);
#else
#else
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
.
mutex
());
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
.
mutex
());
rng_engine_inputs
=
at
::
check_generator
<
at
::
CUDAGeneratorImpl
>
(
gen
)
->
philox_engine_inputs
(
counter_offset
);
rng_engine_inputs
=
at
::
check_generator
<
at
::
CUDAGeneratorImpl
>
(
gen
)
->
philox_engine_inputs
(
counter_offset
);
#endif
#endif
}
}
apex_fused_dropout_kernel
<
scalar_t
,
accscalar_t
,
IndexType
><<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
outputs
,
mask
,
totalElements
,
p
,
rng_engine_inputs
);
apex_fused_dropout_kernel
<
scalar_t
,
accscalar_t
,
IndexType
>
<<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
outputs
,
mask
,
totalElements
,
p
,
rng_engine_inputs
);
C10_CUDA_CHECK
(
cudaGetLastError
());
C10_CUDA_CHECK
(
cudaGetLastError
());
}
}
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
void
apex_dropout_add_cuda
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
typename
accscalar_t
,
scalar_t
*
outputs
,
uint8_t
*
mask
,
typename
IndexType
IndexType
totalElements
,
accscalar_t
p
)
{
>
void
apex_dropout_add_cuda
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
scalar_t
*
outputs
,
uint8_t
*
mask
,
IndexType
totalElements
,
accscalar_t
p
)
{
auto
gen
=
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
();
auto
gen
=
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
();
int
block_size
=
256
;
int
block_size
=
256
;
dim3
dim_block
(
block_size
);
dim3
dim_block
(
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
unsigned
int
blocks_per_sm
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
unsigned
int
blocks_per_sm
=
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
//number of times random will be generated per thread, to offset philox counter in the random state
// number of times random will be generated per thread, to offset philox
int64_t
counter_offset
=
((
totalElements
-
1
)
/
(
block_size
*
grid
.
x
*
UNROLL
)
+
1
)
*
UNROLL
;
// counter in the random state
int64_t
counter_offset
=
((
totalElements
-
1
)
/
(
block_size
*
grid
.
x
*
UNROLL
)
+
1
)
*
UNROLL
;
std
::
pair
<
uint64_t
,
uint64_t
>
rng_engine_inputs
;
std
::
pair
<
uint64_t
,
uint64_t
>
rng_engine_inputs
;
{
{
// See Note [Acquire lock when using random generators]
// See Note [Acquire lock when using random generators]
...
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
...
@@ -253,54 +224,56 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
rng_engine_inputs
=
gen
->
philox_engine_inputs
(
counter_offset
);
rng_engine_inputs
=
gen
->
philox_engine_inputs
(
counter_offset
);
#else
#else
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
.
mutex
());
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
.
mutex
());
rng_engine_inputs
=
at
::
check_generator
<
at
::
CUDAGeneratorImpl
>
(
gen
)
->
philox_engine_inputs
(
counter_offset
);
rng_engine_inputs
=
at
::
check_generator
<
at
::
CUDAGeneratorImpl
>
(
gen
)
->
philox_engine_inputs
(
counter_offset
);
#endif
#endif
}
}
apex_dropout_add_kernel
<
scalar_t
,
accscalar_t
,
IndexType
><<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
add_inputs
,
outputs
,
mask
,
totalElements
,
p
,
rng_engine_inputs
);
apex_dropout_add_kernel
<
scalar_t
,
accscalar_t
,
IndexType
>
<<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
add_inputs
,
outputs
,
mask
,
totalElements
,
p
,
rng_engine_inputs
);
C10_CUDA_CHECK
(
cudaGetLastError
());
C10_CUDA_CHECK
(
cudaGetLastError
());
}
}
template
<
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
scalar_t
,
void
apex_add_cuda
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
typename
accscalar_t
,
scalar_t
*
outputs
,
IndexType
totalElements
)
{
typename
IndexType
>
void
apex_add_cuda
(
scalar_t
const
*
inputs
,
scalar_t
const
*
add_inputs
,
scalar_t
*
outputs
,
IndexType
totalElements
)
{
int
block_size
=
256
;
int
block_size
=
256
;
dim3
dim_block
(
block_size
);
dim3
dim_block
(
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
unsigned
int
blocks_per_sm
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
unsigned
int
blocks_per_sm
=
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
apex_add_kernel
<
scalar_t
,
accscalar_t
,
IndexType
><<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
add_inputs
,
outputs
,
totalElements
);
apex_add_kernel
<
scalar_t
,
accscalar_t
,
IndexType
>
<<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
add_inputs
,
outputs
,
totalElements
);
C10_CUDA_CHECK
(
cudaGetLastError
());
C10_CUDA_CHECK
(
cudaGetLastError
());
}
}
template
<
typename
scalar_t
,
template
<
typename
scalar_t
,
typename
accscalar_t
,
typename
IndexType
>
typename
accscalar_t
,
void
apex_masked_scale_cuda
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
typename
IndexType
uint8_t
const
*
mask
,
IndexType
totalElements
,
>
accscalar_t
scale
)
{
void
apex_masked_scale_cuda
(
scalar_t
const
*
inputs
,
scalar_t
*
outputs
,
uint8_t
const
*
mask
,
IndexType
totalElements
,
accscalar_t
scale
)
{
int
block_size
=
256
;
int
block_size
=
256
;
dim3
dim_block
(
block_size
);
dim3
dim_block
(
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
dim3
grid
((
totalElements
+
block_size
-
1
)
/
block_size
);
unsigned
int
blocks_per_sm
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
unsigned
int
blocks_per_sm
=
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxThreadsPerMultiProcessor
/
block_size
;
grid
.
x
=
std
::
min
((
unsigned
int
)
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
blocks_per_sm
,
grid
.
x
);
apex_masked_scale_kernel
<
scalar_t
,
accscalar_t
,
IndexType
><<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
outputs
,
mask
,
totalElements
,
scale
);
apex_masked_scale_kernel
<
scalar_t
,
accscalar_t
,
IndexType
>
<<<
grid
,
dim_block
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
inputs
,
outputs
,
mask
,
totalElements
,
scale
);
C10_CUDA_CHECK
(
cudaGetLastError
());
C10_CUDA_CHECK
(
cudaGetLastError
());
}
}
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp
View file @
db92ee13
...
@@ -5,103 +5,79 @@ namespace multihead_attn {
...
@@ -5,103 +5,79 @@ namespace multihead_attn {
namespace
encdec
{
namespace
encdec
{
namespace
rocblas_gemmex
{
namespace
rocblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
bool
use_time_mask
,
int
heads
,
torch
::
Tensor
const
&
inputs_q
,
bool
is_training
,
torch
::
Tensor
const
&
inputs_kv
,
int
heads
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
input_weights_q
,
const
uint8_t
*
pad_mask
,
torch
::
Tensor
const
&
input_weights_kv
,
float
dropout_prob
);
torch
::
Tensor
const
&
output_weights
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
);
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
int
heads
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
);
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
);
// C++ interface
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// C++ entry point for the encoder/decoder multi-head attention forward pass.
// Validates ranks and dtypes of every tensor argument, then dispatches to
// fwd_cuda.  Assert order is part of observable behavior (it determines which
// message fires first), so it is preserved as-is.
//
// Enforced preconditions:
//   * inputs_q / inputs_kv are 3-D Half tensors
//   * the three weight tensors are 2-D Half tensors
//   * pad_mask is only inspected (2-D, Byte) when use_mask is set; otherwise
//     a nullptr mask is forwarded to fwd_cuda.
std::vector<torch::Tensor>
fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
    torch::Tensor const &input_weights_q,
    torch::Tensor const &input_weights_kv,
    torch::Tensor const &output_weights, torch::Tensor const &pad_mask,
    float dropout_prob) {
  // Rank checks first, then dtype checks, mirroring the argument order.
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half,
             "Only HALF is supported");
  if (use_mask) {
    // pad_mask is only validated when it will actually be consumed.
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte,
               "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  // Pass a raw byte pointer when masking, nullptr otherwise.
                  use_mask
                      ? static_cast<const uint8_t *>(pad_mask.data_ptr())
                      : nullptr,
                  dropout_prob);
}
std
::
vector
<
torch
::
Tensor
>
bwd
(
std
::
vector
<
torch
::
Tensor
>
int
heads
,
bwd
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
...
@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
...
@@ -115,35 +91,35 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM
(
output_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
output_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_q_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
input_lin_kv_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_q_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_kv_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
dropout_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
return
bwd_cuda
(
return
bwd_cuda
(
heads
,
output_grads
,
matmul2_results
,
dropout_results
,
heads
,
softmax_results
,
input_lin_q_results
,
input_lin_kv_results
,
output_grads
,
inputs_q
,
inputs_kv
,
input_weights_q
,
input_weights_kv
,
matmul2_results
,
output_weights
,
dropout_mask
,
dropout_prob
);
dropout_results
,
softmax_results
,
input_lin_q_results
,
input_lin_kv_results
,
inputs_q
,
inputs_kv
,
input_weights_q
,
input_weights_kv
,
output_weights
,
dropout_mask
,
dropout_prob
);
}
}
}
// end namespace rocblas_gemm_ex
}
// end namespace rocblas_gemm_ex
...
...
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
View file @
db92ee13
#include <vector>
#include <math.h>
#include <iostream>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "dropout.h"
#include "layer_norm.h"
#include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs
#include "strided_batched_gemm.h"
extern
THCState
*
state
;
namespace
multihead_attn
{
namespace
multihead_attn
{
namespace
encdec
{
namespace
encdec
{
namespace
rocblas_gemmex
{
namespace
rocblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
bool
use_time_mask
,
int
heads
,
torch
::
Tensor
const
&
inputs_q
,
bool
is_training
,
torch
::
Tensor
const
&
inputs_kv
,
int
heads
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
input_weights_q
,
const
uint8_t
*
pad_mask
,
torch
::
Tensor
const
&
input_weights_kv
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
output_weights
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
...
@@ -48,7 +39,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -48,7 +39,7 @@ std::vector<torch::Tensor> fwd_cuda(
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
...
@@ -62,25 +53,34 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -62,25 +53,34 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
inputs_q
.
options
().
requires_grad
(
false
);
auto
act_options
=
inputs_q
.
options
().
requires_grad
(
false
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
input_lin_q_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_q_dim
},
act_options
);
torch
::
Tensor
input_lin_q_results
=
torch
::
Tensor
input_lin_kv_results
=
torch
::
empty
({
k_seq_len
,
sequences
,
output_lin_kv_dim
},
act_options
);
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_q_dim
},
act_options
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
input_lin_kv_results
=
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
empty
({
k_seq_len
,
sequences
,
output_lin_kv_dim
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
softmax_results
=
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs_q
,
act_options
);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_q_results
.
data_ptr
());
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_q_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_kv_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
);
static_cast
<
void
*>
(
input_lin_kv_results
.
data_ptr
());
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
...
@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -140,8 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags
));
flags
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_t
,
a_layout_t
,
b_layout_n
,
b_layout_n
,
k_seq_len
,
k_seq_len
,
q_seq_len
,
q_seq_len
,
...
@@ -166,46 +165,35 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -166,46 +165,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
if
(
pad_mask
==
nullptr
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
k_seq_len
,
attn_batches
*
q_seq_len
);
}
else
{
}
else
{
if
(
use_time_mask
)
{
if
(
use_time_mask
)
{
softmax_success
=
dispatch_time_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_time_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
q_seq_len
);
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
q_seq_len
);
}
else
{
}
else
{
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
}
}
}
}
assert
(
softmax_success
);
assert
(
softmax_success
);
if
(
is_training
)
{
if
(
is_training
)
{
apex_fused_dropout_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
apex_fused_dropout_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
dropout_elems
,
(
1.0
f
-
dropout_prob
));
(
1.0
f
-
dropout_prob
));
}
}
// Matmul2
// Matmul2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_n
,
b_layout_n
,
head_dim
,
head_dim
,
q_seq_len
,
q_seq_len
,
...
@@ -253,34 +241,24 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -253,34 +241,24 @@ std::vector<torch::Tensor> fwd_cuda(
flags
));
flags
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
return
{
input_lin_q_results
,
input_lin_q_results
,
input_lin_kv_results
,
input_lin_kv_results
,
softmax_results
,
softmax_results
,
dropout_results
,
dropout_results
,
dropout_mask
,
dropout_mask
,
matmul2_results
,
matmul2_results
,
outputs
outputs
};
};
}
}
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
int
heads
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
...
@@ -292,7 +270,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -292,7 +270,7 @@ std::vector<torch::Tensor> bwd_cuda(
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
...
@@ -316,15 +294,20 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -316,15 +294,20 @@ std::vector<torch::Tensor> bwd_cuda(
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
at
::
Tensor
input_lin_kv_output_grads
=
torch
::
empty_like
(
input_lin_kv_results
);
at
::
Tensor
input_lin_kv_output_grads
=
torch
::
empty_like
(
input_lin_kv_results
);
auto
q_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_q_results
.
data_ptr
());
auto
q_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_q_results
.
data_ptr
());
auto
k_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
());
auto
k_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
());
auto
v_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
;
auto
v_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
;
auto
q_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_q_output_grads
.
data_ptr
());
auto
q_lin_grads_ptr
=
auto
k_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
());
static_cast
<
half
*>
(
input_lin_q_output_grads
.
data_ptr
());
auto
v_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
())
+
head_dim
;
auto
k_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
());
auto
v_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
())
+
head_dim
;
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
...
@@ -386,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -386,8 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags
));
flags
));
// MatMul2 Dgrad1
// MatMul2 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_t
,
a_layout_t
,
b_layout_n
,
b_layout_n
,
k_seq_len
,
k_seq_len
,
q_seq_len
,
q_seq_len
,
...
@@ -409,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -409,8 +391,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
attn_batches
);
// Matmul2 Dgrad2
// Matmul2 Dgrad2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_t
,
b_layout_t
,
head_dim
,
head_dim
,
k_seq_len
,
k_seq_len
,
...
@@ -442,17 +423,14 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -442,17 +423,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
// Softmax Grad
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
k_seq_len
,
attn_batches
*
q_seq_len
);
assert
(
softmax_success
);
assert
(
softmax_success
);
// Matmul1 Dgrad1
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_n
,
b_layout_n
,
head_dim
,
head_dim
,
q_seq_len
,
q_seq_len
,
...
@@ -474,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -474,8 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
attn_batches
);
// Matmul1 Dgrad2
// Matmul1 Dgrad2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_t
,
b_layout_t
,
head_dim
,
head_dim
,
k_seq_len
,
k_seq_len
,
...
@@ -612,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -612,3 +589,4 @@ std::vector<torch::Tensor> bwd_cuda(
}
// end namespace rocblas_gemmex
}
// end namespace rocblas_gemmex
}
// end namespace encdec
}
// end namespace encdec
}
// end namespace multihead_attn
}
// end namespace multihead_attn
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp
View file @
db92ee13
...
@@ -5,66 +5,49 @@ namespace multihead_attn {
...
@@ -5,66 +5,49 @@ namespace multihead_attn {
namespace
encdec_norm_add
{
namespace
encdec_norm_add
{
namespace
rocblas_gemmex
{
namespace
rocblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
bool
use_time_mask
,
int
heads
,
torch
::
Tensor
const
&
inputs_q
,
bool
is_training
,
torch
::
Tensor
const
&
inputs_kv
,
int
heads
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
input_weights_q
,
const
uint8_t
*
pad_mask
,
torch
::
Tensor
const
&
input_weights_kv
,
float
dropout_prob
);
torch
::
Tensor
const
&
output_weights
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
);
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
int
heads
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
);
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
);
// C++ interface
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std
::
vector
<
torch
::
Tensor
>
fwd
(
std
::
vector
<
torch
::
Tensor
>
bool
use_mask
,
fwd
(
bool
use_mask
,
bool
use_time_mask
,
bool
is_training
,
int
heads
,
bool
use_time_mask
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
inputs_kv
,
bool
is_training
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
int
heads
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
pad_mask
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
pad_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
inputs_q
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
inputs_q
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
inputs_kv
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
inputs_kv
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
dim
()
==
1
,
"expected 1D tensor"
);
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
dim
()
==
1
,
"expected 1D tensor"
);
...
@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
...
@@ -73,58 +56,48 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM
(
input_weights_kv
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
input_weights_kv
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
output_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
output_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
lyr_nrm_beta_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_beta_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
if
(
use_mask
)
{
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
}
}
return
fwd_cuda
(
return
fwd_cuda
(
use_time_mask
,
is_training
,
heads
,
inputs_q
,
inputs_kv
,
use_time_mask
,
lyr_nrm_gamma_weights
,
lyr_nrm_beta_weights
,
input_weights_q
,
is_training
,
input_weights_kv
,
output_weights
,
heads
,
use_mask
?
static_cast
<
const
uint8_t
*>
(
pad_mask
.
data_ptr
())
inputs_q
,
:
nullptr
,
inputs_kv
,
dropout_prob
);
lyr_nrm_gamma_weights
,
lyr_nrm_beta_weights
,
input_weights_q
,
input_weights_kv
,
output_weights
,
use_mask
?
static_cast
<
const
uint8_t
*>
(
pad_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
}
std
::
vector
<
torch
::
Tensor
>
bwd
(
std
::
vector
<
torch
::
Tensor
>
int
heads
,
bwd
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
...
@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
...
@@ -144,47 +117,49 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_add_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_add_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_q_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
input_lin_kv_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
lyr_nrm_mean
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
"Only FLOAT is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_invvar
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
"Only FLOAT is supported"
);
AT_ASSERTM
(
input_lin_q_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_kv_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_beta_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_mean
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only FLOAT is supported"
);
AT_ASSERTM
(
dropout_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
AT_ASSERTM
(
lyr_nrm_invvar
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
AT_ASSERTM
(
dropout_add_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
"Only FLOAT is supported"
);
AT_ASSERTM
(
inputs_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_gamma_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
lyr_nrm_beta_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_q
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights_kv
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
AT_ASSERTM
(
dropout_add_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
return
bwd_cuda
(
return
bwd_cuda
(
heads
,
output_grads
,
matmul2_results
,
dropout_results
,
heads
,
softmax_results
,
input_lin_q_results
,
input_lin_kv_results
,
output_grads
,
lyr_nrm_results
,
lyr_nrm_mean
,
lyr_nrm_invvar
,
inputs_q
,
matmul2_results
,
inputs_kv
,
lyr_nrm_gamma_weights
,
lyr_nrm_beta_weights
,
dropout_results
,
input_weights_q
,
input_weights_kv
,
output_weights
,
softmax_results
,
dropout_mask
,
dropout_add_mask
,
dropout_prob
);
input_lin_q_results
,
input_lin_kv_results
,
lyr_nrm_results
,
lyr_nrm_mean
,
lyr_nrm_invvar
,
inputs_q
,
inputs_kv
,
lyr_nrm_gamma_weights
,
lyr_nrm_beta_weights
,
input_weights_q
,
input_weights_kv
,
output_weights
,
dropout_mask
,
dropout_add_mask
,
dropout_prob
);
}
}
}
// end namespace cublas_gemmex
}
// end namespace cublas_gemmex
...
@@ -195,4 +170,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -195,4 +170,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"forward"
,
&
multihead_attn
::
encdec_norm_add
::
rocblas_gemmex
::
fwd
,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."
);
m
.
def
(
"forward"
,
&
multihead_attn
::
encdec_norm_add
::
rocblas_gemmex
::
fwd
,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
encdec_norm_add
::
rocblas_gemmex
::
bwd
,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
encdec_norm_add
::
rocblas_gemmex
::
bwd
,
"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."
);
}
}
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
View file @
db92ee13
#include <vector>
#include <math.h>
#include <iostream>
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "dropout.h"
#include "layer_norm.h"
#include "layer_norm.h"
#include "softmax.h"
// symbol to be automatically resolved by PyTorch libs
#include "strided_batched_gemm.h"
extern
THCState
*
state
;
namespace
multihead_attn
{
namespace
multihead_attn
{
namespace
encdec_norm_add
{
namespace
encdec_norm_add
{
...
@@ -64,7 +61,8 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -64,7 +61,8 @@ std::vector<torch::Tensor> fwd_cuda(
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
inputs_q
.
options
().
requires_grad
(
false
);
auto
act_options
=
inputs_q
.
options
().
requires_grad
(
false
);
auto
lyr_nrm_options
=
act_options
.
dtype
(
torch
::
kFloat32
);
auto
lyr_nrm_options
=
act_options
.
dtype
(
torch
::
kFloat32
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
...
@@ -73,23 +71,31 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -73,23 +71,31 @@ std::vector<torch::Tensor> fwd_cuda(
torch
::
Tensor
lyr_nrm_invvar
=
torch
::
empty
({
batches_q
},
lyr_nrm_options
);
torch
::
Tensor
lyr_nrm_invvar
=
torch
::
empty
({
batches_q
},
lyr_nrm_options
);
torch
::
Tensor
lyr_nrm_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
lyr_nrm_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
input_lin_q_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_q_dim
},
act_options
);
torch
::
Tensor
input_lin_q_results
=
torch
::
Tensor
input_lin_kv_results
=
torch
::
empty
({
k_seq_len
,
sequences
,
output_lin_kv_dim
},
act_options
);
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_q_dim
},
act_options
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
input_lin_kv_results
=
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
empty
({
k_seq_len
,
sequences
,
output_lin_kv_dim
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
softmax_results
=
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
output_lin_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
output_lin_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
dropout_add_mask
=
torch
::
empty_like
(
inputs_q
,
mask_options
);
torch
::
Tensor
dropout_add_mask
=
torch
::
empty_like
(
inputs_q
,
mask_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs_q
,
act_options
);
// Input Linear Results Pointers to Q, K, and V of interviewed activations
// Input Linear Results Pointers to Q, K, and V of interviewed activations
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_q_results
.
data_ptr
());
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_q_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_kv_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
);
static_cast
<
void
*>
(
input_lin_kv_results
.
data_ptr
());
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
...
@@ -97,16 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -97,16 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
// Layer Norm
HostApplyLayerNorm
<
at
::
Half
,
float
>
(
HostApplyLayerNorm
<
at
::
Half
,
float
>
(
static_cast
<
at
::
Half
*>
(
lyr_nrm_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
lyr_nrm_results
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_mean
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_mean
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_invvar
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_invvar
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
int
>
(
batches_q
),
// n1
static_cast
<
int
>
(
batches_q
),
// n1
static_cast
<
int
>
(
embed_dim
),
// n2
static_cast
<
int
>
(
embed_dim
),
// n2
1.0e-5
,
1.0e-5
,
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_gamma_weights
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_gamma_weights
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_beta_weights
.
data_ptr
()));
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_beta_weights
.
data_ptr
()));
// Input Linear Q Fwd
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
...
@@ -161,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -161,8 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
solution_index
,
solution_index
,
flags
));
flags
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_t
,
a_layout_t
,
b_layout_n
,
b_layout_n
,
k_seq_len
,
k_seq_len
,
q_seq_len
,
q_seq_len
,
...
@@ -187,46 +191,35 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -187,46 +191,35 @@ std::vector<torch::Tensor> fwd_cuda(
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
if
(
pad_mask
==
nullptr
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
k_seq_len
,
attn_batches
*
q_seq_len
);
}
else
{
}
else
{
if
(
use_time_mask
)
{
if
(
use_time_mask
)
{
softmax_success
=
dispatch_time_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_time_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
q_seq_len
);
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
q_seq_len
);
}
else
{
}
else
{
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
}
}
}
}
assert
(
softmax_success
);
assert
(
softmax_success
);
if
(
is_training
)
{
if
(
is_training
)
{
apex_fused_dropout_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
apex_fused_dropout_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
dropout_elems
,
(
1.0
f
-
dropout_prob
));
(
1.0
f
-
dropout_prob
));
}
}
// Matmul2
// Matmul2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_n
,
b_layout_n
,
head_dim
,
head_dim
,
q_seq_len
,
q_seq_len
,
...
@@ -276,25 +269,22 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -276,25 +269,22 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
// End-of-block Dropout-Add
if
(
is_training
)
{
if
(
is_training
)
{
apex_dropout_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
apex_dropout_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_add_mask
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens_q
,
total_tokens_q
,
(
1.0
f
-
dropout_prob
));
(
1.0
f
-
dropout_prob
));
}
else
{
}
else
{
apex_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
apex_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
total_tokens_q
);
total_tokens_q
);
}
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
return
{
lyr_nrm_results
,
lyr_nrm_results
,
lyr_nrm_mean
,
lyr_nrm_mean
,
lyr_nrm_invvar
,
lyr_nrm_invvar
,
input_lin_q_results
,
input_lin_q_results
,
...
@@ -304,33 +294,22 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -304,33 +294,22 @@ std::vector<torch::Tensor> fwd_cuda(
dropout_mask
,
dropout_mask
,
matmul2_results
,
matmul2_results
,
dropout_add_mask
,
dropout_add_mask
,
outputs
outputs
};
};
}
}
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
int
heads
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
...
@@ -343,7 +322,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -343,7 +322,7 @@ std::vector<torch::Tensor> bwd_cuda(
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
...
@@ -370,16 +349,21 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -370,16 +349,21 @@ std::vector<torch::Tensor> bwd_cuda(
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
at
::
Tensor
input_lin_kv_output_grads
=
torch
::
empty_like
(
input_lin_kv_results
);
at
::
Tensor
input_lin_kv_output_grads
=
torch
::
empty_like
(
input_lin_kv_results
);
at
::
Tensor
input_lin_q_grads
=
torch
::
empty_like
(
inputs_q
);
at
::
Tensor
input_lin_q_grads
=
torch
::
empty_like
(
inputs_q
);
auto
q_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_q_results
.
data_ptr
());
auto
q_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_q_results
.
data_ptr
());
auto
k_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
());
auto
k_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
());
auto
v_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
;
auto
v_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
;
auto
q_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_q_output_grads
.
data_ptr
());
auto
q_lin_grads_ptr
=
auto
k_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
());
static_cast
<
half
*>
(
input_lin_q_output_grads
.
data_ptr
());
auto
v_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
())
+
head_dim
;
auto
k_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
());
auto
v_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
())
+
head_dim
;
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
...
@@ -449,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -449,8 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags
));
flags
));
// MatMul2 Dgrad1
// MatMul2 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_t
,
a_layout_t
,
b_layout_n
,
b_layout_n
,
k_seq_len
,
k_seq_len
,
q_seq_len
,
q_seq_len
,
...
@@ -472,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -472,8 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
attn_batches
);
// Matmul2 Dgrad2
// Matmul2 Dgrad2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_t
,
b_layout_t
,
head_dim
,
head_dim
,
k_seq_len
,
k_seq_len
,
...
@@ -505,17 +487,14 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -505,17 +487,14 @@ std::vector<torch::Tensor> bwd_cuda(
// Softmax Grad
// Softmax Grad
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
k_seq_len
,
attn_batches
*
q_seq_len
);
assert
(
softmax_success
);
assert
(
softmax_success
);
// Matmul1 Dgrad1
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_n
,
b_layout_n
,
head_dim
,
head_dim
,
q_seq_len
,
q_seq_len
,
...
@@ -537,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -537,8 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
attn_batches
);
// Matmul1 Dgrad2
// Matmul1 Dgrad2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
a_layout_n
,
a_layout_n
,
b_layout_t
,
b_layout_t
,
head_dim
,
head_dim
,
k_seq_len
,
k_seq_len
,
...
@@ -683,17 +661,12 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -683,17 +661,12 @@ std::vector<torch::Tensor> bwd_cuda(
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
return
{
input_q_grads
,
input_kv_grads
,
lyr_nrm_gamma_grads
,
input_q_grads
,
lyr_nrm_beta_grads
,
input_weight_q_grads
,
input_weight_kv_grads
,
input_kv_grads
,
output_weight_grads
};
lyr_nrm_gamma_grads
,
lyr_nrm_beta_grads
,
input_weight_q_grads
,
input_weight_kv_grads
,
output_weight_grads
};
}
}
}
// end namespace rocblas_gemmex
}
// end namespace rocblas_gemmex
}
// end namespace encdec_norm_add
}
// end namespace encdec_norm_add
}
// end namespace multihead_attn
}
// end namespace multihead_attn
apex/contrib/csrc/multihead_attn/layer_norm.h
View file @
db92ee13
...
@@ -4,14 +4,8 @@
...
@@ -4,14 +4,8 @@
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
template
<
typename
U
>
template
<
typename
U
>
__device__
__device__
void
cuWelfordOnlineSum
(
const
U
curr
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
void
cuWelfordOnlineSum
(
const
U
curr
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
count
=
count
+
U
(
1
);
count
=
count
+
U
(
1
);
U
delta
=
curr
-
mu
;
U
delta
=
curr
-
mu
;
U
lmean
=
mu
+
delta
/
count
;
U
lmean
=
mu
+
delta
/
count
;
...
@@ -20,15 +14,9 @@ void cuWelfordOnlineSum(
...
@@ -20,15 +14,9 @@ void cuWelfordOnlineSum(
sigma2
=
sigma2
+
delta
*
delta2
;
sigma2
=
sigma2
+
delta
*
delta2
;
}
}
template
<
typename
U
>
__device__
template
<
typename
U
>
void
cuChanOnlineSum
(
__device__
void
cuChanOnlineSum
(
const
U
muB
,
const
U
sigma2B
,
const
U
countB
,
const
U
muB
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
const
U
sigma2B
,
const
U
countB
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
U
delta
=
muB
-
mu
;
U
delta
=
muB
-
mu
;
U
nA
=
count
;
U
nA
=
count
;
U
nB
=
countB
;
U
nB
=
countB
;
...
@@ -37,7 +25,7 @@ void cuChanOnlineSum(
...
@@ -37,7 +25,7 @@ void cuChanOnlineSum(
if
(
nX
>
U
(
0
))
{
if
(
nX
>
U
(
0
))
{
nA
=
nA
/
nX
;
nA
=
nA
/
nX
;
nB
=
nB
/
nX
;
nB
=
nB
/
nX
;
mu
=
nA
*
mu
+
nB
*
muB
;
mu
=
nA
*
mu
+
nB
*
muB
;
sigma2
=
sigma2
+
sigma2B
+
delta
*
delta
*
nA
*
nB
*
nX
;
sigma2
=
sigma2
+
sigma2B
+
delta
*
delta
*
nA
*
nB
*
nX
;
}
else
{
}
else
{
mu
=
U
(
0
);
mu
=
U
(
0
);
...
@@ -45,16 +33,10 @@ void cuChanOnlineSum(
...
@@ -45,16 +33,10 @@ void cuChanOnlineSum(
}
}
}
}
template
<
typename
T
,
typename
U
>
__device__
template
<
typename
T
,
typename
U
>
void
cuWelfordMuSigma2
(
__device__
void
cuWelfordMuSigma2
(
const
T
*
__restrict__
vals
,
const
int
n1
,
const
T
*
__restrict__
vals
,
const
int
n2
,
const
int
i1
,
U
&
mu
,
U
&
sigma2
,
const
int
n1
,
U
*
buf
)
{
const
int
n2
,
const
int
i1
,
U
&
mu
,
U
&
sigma2
,
U
*
buf
)
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 2) Tensor is contiguous
...
@@ -62,7 +44,7 @@ void cuWelfordMuSigma2(
...
@@ -62,7 +44,7 @@ void cuWelfordMuSigma2(
//
//
// compute variance and mean over n2
// compute variance and mean over n2
U
count
=
U
(
0
);
U
count
=
U
(
0
);
mu
=
U
(
0
);
mu
=
U
(
0
);
sigma2
=
U
(
0
);
sigma2
=
U
(
0
);
if
(
i1
<
n1
)
{
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// one warp normalizes one n1 index,
...
@@ -70,17 +52,17 @@ void cuWelfordMuSigma2(
...
@@ -70,17 +52,17 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
4
*
thrx
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
+
k
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
l
+
k
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
}
}
}
for
(;
l
<
n2
;
++
l
)
{
for
(;
l
<
n2
;
++
l
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
l
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
}
// intra-warp reductions
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
...
@@ -93,23 +75,24 @@ void cuWelfordMuSigma2(
...
@@ -93,23 +75,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
if
(
blockDim
.
y
>
1
)
{
U
*
ubuf
=
(
U
*
)
buf
;
U
*
ubuf
=
(
U
*
)
buf
;
U
*
ibuf
=
(
U
*
)(
ubuf
+
blockDim
.
y
);
U
*
ibuf
=
(
U
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
ibuf
[
wrt_y
]
=
count
;
}
}
__syncthreads
();
__syncthreads
();
// lower half merges
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
U
muB
=
ubuf
[
2
*
threadIdx
.
y
];
U
muB
=
ubuf
[
2
*
threadIdx
.
y
];
U
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
U
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
U
countB
=
ibuf
[
threadIdx
.
y
];
U
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -120,7 +103,7 @@ void cuWelfordMuSigma2(
...
@@ -120,7 +103,7 @@ void cuWelfordMuSigma2(
}
}
__syncthreads
();
__syncthreads
();
mu
=
ubuf
[
0
];
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
U
(
n2
);
sigma2
=
ubuf
[
1
]
/
U
(
n2
);
// don't care about final value of count, we know count == n2
// don't care about final value of count, we know count == n2
}
else
{
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
...
@@ -129,16 +112,10 @@ void cuWelfordMuSigma2(
...
@@ -129,16 +112,10 @@ void cuWelfordMuSigma2(
}
}
}
}
template
<
>
__device__
template
<
>
void
cuWelfordMuSigma2
(
__device__
void
cuWelfordMuSigma2
(
const
at
::
Half
*
__restrict__
vals
,
const
at
::
Half
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
const
int
n1
,
float
&
mu
,
float
&
sigma2
,
float
*
buf
)
{
const
int
n2
,
const
int
i1
,
float
&
mu
,
float
&
sigma2
,
float
*
buf
)
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 2) Tensor is contiguous
...
@@ -146,7 +123,7 @@ void cuWelfordMuSigma2(
...
@@ -146,7 +123,7 @@ void cuWelfordMuSigma2(
//
//
// compute variance and mean over n2
// compute variance and mean over n2
float
count
=
0.0
f
;
float
count
=
0.0
f
;
mu
=
float
(
0
);
mu
=
float
(
0
);
sigma2
=
float
(
0
);
sigma2
=
float
(
0
);
if
(
i1
<
n1
)
{
if
(
i1
<
n1
)
{
...
@@ -155,28 +132,28 @@ void cuWelfordMuSigma2(
...
@@ -155,28 +132,28 @@ void cuWelfordMuSigma2(
// initialize with standard Welford algorithm
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
at
::
Half
*
lvals
=
vals
+
i1
*
n2
;
const
at
::
Half
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
8
*
thrx
;
int
l
=
8
*
thrx
;
if
((((
size_t
)
lvals
)
&
3
)
!=
0
)
{
if
((((
size_t
)
lvals
)
&
3
)
!=
0
)
{
// 16 bit alignment
// 16 bit alignment
// first thread consumes first point
// first thread consumes first point
if
(
thrx
==
0
)
{
if
(
thrx
==
0
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
0
]);
float
curr
=
static_cast
<
float
>
(
lvals
[
0
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
}
++
l
;
++
l
;
}
}
// at this point, lvals[l] are 32 bit aligned for all threads.
// at this point, lvals[l] are 32 bit aligned for all threads.
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
}
}
for
(;
l
<
n2
;
++
l
)
{
for
(;
l
<
n2
;
++
l
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
l
]);
float
curr
=
static_cast
<
float
>
(
lvals
[
l
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
}
// intra-warp reductions
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
...
@@ -189,23 +166,24 @@ void cuWelfordMuSigma2(
...
@@ -189,23 +166,24 @@ void cuWelfordMuSigma2(
// threadIdx.x == 0 has correct values for each warp
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
if
(
blockDim
.
y
>
1
)
{
float
*
ubuf
=
(
float
*
)
buf
;
float
*
ubuf
=
(
float
*
)
buf
;
float
*
ibuf
=
(
float
*
)(
ubuf
+
blockDim
.
y
);
float
*
ibuf
=
(
float
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
ibuf
[
wrt_y
]
=
count
;
}
}
__syncthreads
();
__syncthreads
();
// lower half merges
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
float
muB
=
ubuf
[
2
*
threadIdx
.
y
];
float
muB
=
ubuf
[
2
*
threadIdx
.
y
];
float
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
float
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
float
countB
=
ibuf
[
threadIdx
.
y
];
float
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -216,7 +194,7 @@ void cuWelfordMuSigma2(
...
@@ -216,7 +194,7 @@ void cuWelfordMuSigma2(
}
}
__syncthreads
();
__syncthreads
();
mu
=
ubuf
[
0
];
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
float
(
n2
);
sigma2
=
ubuf
[
1
]
/
float
(
n2
);
// don't care about final value of count, we know count == n2
// don't care about final value of count, we know count == n2
}
else
{
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
...
@@ -246,8 +224,9 @@ template<> double rsqrt(double v) {
...
@@ -246,8 +224,9 @@ template<> double rsqrt(double v) {
}
}
namespace
{
namespace
{
// This is the un-specialized struct. Note that we prevent instantiation of this
// This is the un-specialized struct. Note that we prevent instantiation of
// struct by putting an undefined symbol in the function body so it won't compile.
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// template <typename T>
// struct SharedMemory
// struct SharedMemory
// {
// {
...
@@ -260,64 +239,50 @@ namespace {
...
@@ -260,64 +239,50 @@ namespace {
// }
// }
// };
// };
// https://github.com/NVIDIA/apex/issues/246
// https://github.com/NVIDIA/apex/issues/246
template
<
typename
T
>
template
<
typename
T
>
struct
SharedMemory
;
struct
SharedMemory
;
template
<
>
template
<
>
struct
SharedMemory
<
float
>
{
struct
SharedMemory
<
float
>
__device__
float
*
getPointer
()
{
{
__device__
float
*
getPointer
()
{
extern
__shared__
float
s_float
[];
extern
__shared__
float
s_float
[];
return
s_float
;
return
s_float
;
}
}
};
};
template
<
>
template
<
>
struct
SharedMemory
<
double
>
{
struct
SharedMemory
<
double
>
__device__
double
*
getPointer
()
{
{
__device__
double
*
getPointer
()
{
extern
__shared__
double
s_double
[];
extern
__shared__
double
s_double
[];
return
s_double
;
return
s_double
;
}
}
};
};
}
}
// namespace
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
>
void
cuApplyLayerNorm
(
__global__
void
T
*
__restrict__
output_vals
,
cuApplyLayerNorm
(
T
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
U
*
__restrict__
invvar
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
T
*
__restrict__
vals
,
const
T
*
__restrict__
gamma
,
const
T
*
__restrict__
beta
)
{
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
T
*
__restrict__
gamma
,
const
T
*
__restrict__
beta
)
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
// 2) Tensors are contiguous
//
//
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
U
mu
,
sigma2
;
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
const
T
*
lvals
=
vals
+
i1
*
n2
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
T
*
ovals
=
output_vals
+
i1
*
n2
;
T
*
ovals
=
output_vals
+
i1
*
n2
;
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
}
}
}
else
{
}
else
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
));
ovals
[
i
]
=
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
));
}
}
...
@@ -329,36 +294,26 @@ void cuApplyLayerNorm(
...
@@ -329,36 +294,26 @@ void cuApplyLayerNorm(
}
}
}
}
template
<
typename
T
,
typename
U
>
__device__
template
<
typename
T
,
typename
U
>
void
cuLoadWriteStridedInputs
(
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
thr_load_row_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
int
thr_load_col_off
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
int
i2_off
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
const
int
row_stride
,
int
i1
=
i1_block
+
thr_load_row_off
;
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
...
@@ -366,78 +321,71 @@ void cuLoadWriteStridedInputs(
...
@@ -366,78 +321,71 @@ void cuLoadWriteStridedInputs(
}
}
}
else
{
}
else
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
}
}
}
template
<
typename
T
,
typename
U
>
__device__
template
<
typename
T
,
typename
U
>
void
cuLoadAddStridedInputs
(
__device__
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
thr_load_row_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
int
thr_load_col_off
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
int
i2_off
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
const
int
row_stride
,
int
i1
=
i1_block
+
thr_load_row_off
;
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
>
void
cuComputePartGradGammaBeta
(
__global__
void
cuComputePartGradGammaBeta
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
T
*
__restrict__
input
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
const
int
n1
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
n2
,
const
int
numsegs_n1
=
const
U
*
__restrict__
mean
,
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
numsegs_n1
=
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y *
U
*
warp_buf1
=
(
U
*
)
buf
;
// blockDim.y + (blockDim.y -
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// 1)*(blockDim.x/blockDim.y) elements
U
*
warp_buf1
=
(
U
*
)
buf
;
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// compute partial sums from strided inputs
// compute partial sums from strided inputs
// do this to increase number of loads in flight
// do this to increase number of loads in flight
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
i1_end
,
n2
,
mean
,
invvar
);
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
}
}
__syncthreads
();
__syncthreads
();
// inter-warp reductions
// inter-warp reductions
...
@@ -445,21 +393,21 @@ void cuComputePartGradGammaBeta(
...
@@ -445,21 +393,21 @@ void cuComputePartGradGammaBeta(
U
acc1
=
U
(
0
);
U
acc1
=
U
(
0
);
U
acc2
=
U
(
0
);
U
acc2
=
U
(
0
);
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
acc1
+=
warp_buf1
[
idx1
];
acc1
+=
warp_buf1
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
}
}
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
__syncthreads
();
__syncthreads
();
// sum all warps
// sum all warps
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
}
...
@@ -469,53 +417,51 @@ void cuComputePartGradGammaBeta(
...
@@ -469,53 +417,51 @@ void cuComputePartGradGammaBeta(
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
int
row1
=
threadIdx
.
y
;
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
1
;
int
row2
=
threadIdx
.
y
+
1
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
>
void
cuComputeGradGammaBeta
(
__global__
void
const
U
*
part_grad_gamma
,
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
n1
,
const
int
n2
,
const
int
part_size
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
const
int
n1
,
const
int
n2
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
// sum partial gradients for gamma and beta
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
// each warp does sequential reductions until reduced part_size is num_warps
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
U
sum_gamma
=
U
(
0
);
U
sum_gamma
=
U
(
0
);
U
sum_beta
=
U
(
0
);
U
sum_beta
=
U
(
0
);
const
U
*
part_grad_gamma_ptr
=
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
const
U
*
part_grad_gamma_ptr
=
const
U
*
part_grad_beta_ptr
=
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
const
U
*
part_grad_beta_ptr
=
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
}
}
// inter-warp reductions
// inter-warp reductions
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
// top half write to shared memory
// top half write to shared memory
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
}
}
__syncthreads
();
__syncthreads
();
// bottom half sums
// bottom half sums
if
(
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_gamma
+=
buf
[
read_idx
];
sum_gamma
+=
buf
[
read_idx
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -527,51 +473,46 @@ void cuComputeGradGammaBeta(
...
@@ -527,51 +473,46 @@ void cuComputeGradGammaBeta(
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
>
void
cuComputeGradInput
(
__global__
void
const
T
*
__restrict__
dout
,
cuComputeGradInput
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
dout_resid
,
const
T
*
__restrict__
dout_resid
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
T
*
__restrict__
input
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
const
int
n1
,
U
epsilon
,
const
T
*
gamma
,
T
*
grad_input
)
{
const
int
n2
,
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
T
*
gamma
,
T
*
grad_input
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
const
T
*
k_dout_resid
=
dout_resid
+
i1
*
n2
;
const
T
*
k_dout_resid
=
dout_resid
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
if
(
gamma
!=
NULL
)
{
int
l
=
4
*
thrx
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
}
}
for
(;
l
<
n2
;
++
l
)
{
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
}
else
{
}
else
{
int
l
=
4
*
thrx
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
;
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
...
@@ -591,161 +532,121 @@ void cuComputeGradInput(
...
@@ -591,161 +532,121 @@ void cuComputeGradInput(
// inter-warp reductions
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
}
}
__syncthreads
();
__syncthreads
();
// lower half merges
// lower half merges
if
(
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
}
}
__syncthreads
();
__syncthreads
();
}
}
if
(
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
y
==
0
)
{
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
y
!=
0
)
{
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
}
}
// all threads now have the two sums over l
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
fH
=
(
U
)
n2
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
if
(
gamma
!=
NULL
)
{
if
(
gamma
!=
NULL
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
}
else
{
}
else
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
;
U
f_grad_input
=
fH
*
c_loss
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
}
}
}
}
}
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
void
HostApplyLayerNorm
(
void
HostApplyLayerNorm
(
T
*
output
,
U
*
mean
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
T
*
output
,
int
n2
,
double
epsilon
,
const
T
*
gamma
,
const
T
*
beta
)
{
U
*
mean
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
int
n2
,
double
epsilon
,
const
T
*
gamma
,
const
T
*
beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
int
nshared
=
threads
.
y
>
1
?
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
}
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
void
HostLayerNormGradient
(
void
HostLayerNormGradient
(
const
T
*
dout
,
const
T
*
dout_resid
,
const
U
*
mean
,
const
T
*
dout
,
const
U
*
invvar
,
const
at
::
Tensor
&
input
,
int
n1
,
const
T
*
dout_resid
,
int
n2
,
const
T
*
gamma
,
const
T
*
beta
,
const
U
*
mean
,
double
epsilon
,
T
*
grad_input
,
T
*
grad_gamma
,
const
U
*
invvar
,
T
*
grad_beta
)
{
const
at
::
Tensor
&
input
,
int
n1
,
int
n2
,
const
T
*
gamma
,
const
T
*
beta
,
double
epsilon
,
T
*
grad_input
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
// compute grad_gamma(j) and grad_beta(j)
// compute grad_gamma(j) and grad_beta(j)
const
int
part_size
=
16
;
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
.
options
().
dtype
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar_type
()));
at
::
Tensor
part_grad_gamma
=
at
::
empty
(
{
part_size
,
n2
},
input
.
options
().
dtype
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar_type
()));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
dout
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
U
(
epsilon
),
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
n1
,
n2
,
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()));
mean
,
invvar
,
U
(
epsilon
),
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()));
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()),
part_size
,
n1
,
n2
,
part_size
,
grad_gamma
,
grad_beta
);
n1
,
n2
,
grad_gamma
,
grad_beta
);
}
}
// compute grad_input
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
const
dim3
threads1
(
32
,
4
,
1
);
int
nshared
=
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
dout
,
dout
,
dout_resid
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
dout_resid
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
}
}
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp
View file @
db92ee13
...
@@ -5,81 +5,66 @@ namespace multihead_attn {
...
@@ -5,81 +5,66 @@ namespace multihead_attn {
namespace
fused_softmax
{
namespace
fused_softmax
{
namespace
mask_softmax_dropout
{
namespace
mask_softmax_dropout
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
bool
is_training
,
torch
::
Tensor
const
&
input
,
int
heads
,
const
uint8_t
*
pad_mask
,
torch
::
Tensor
const
&
input
,
float
dropout_prob
);
const
uint8_t
*
pad_mask
,
float
dropout_prob
);
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
int
heads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
softmax_results
,
const
uint8_t
*
padding_mask
,
float
dropout_prob
);
torch
::
Tensor
const
&
dropout_mask
,
const
uint8_t
*
padding_mask
,
float
dropout_prob
);
// C++ interface
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) \
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std
::
vector
<
torch
::
Tensor
>
fwd
(
std
::
vector
<
torch
::
Tensor
>
fwd
(
bool
use_mask
,
bool
is_training
,
int
heads
,
bool
use_mask
,
torch
::
Tensor
const
&
input
,
bool
is_training
,
torch
::
Tensor
const
&
pad_mask
,
int
heads
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
pad_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
if
(
use_mask
)
{
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
}
}
return
fwd_cuda
(
return
fwd_cuda
(
is_training
,
heads
,
input
,
is_training
,
use_mask
?
static_cast
<
const
uint8_t
*>
(
pad_mask
.
data_ptr
())
heads
,
:
nullptr
,
input
,
dropout_prob
);
use_mask
?
static_cast
<
const
uint8_t
*>
(
pad_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
}
torch
::
Tensor
bwd
(
torch
::
Tensor
bwd
(
bool
use_mask
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
bool
use_mask
,
torch
::
Tensor
const
&
softmax_results
,
int
heads
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
padding_mask
,
float
dropout_prob
)
{
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
padding_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
"Only HALF is supported"
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return
bwd_cuda
(
return
bwd_cuda
(
heads
,
output_grads
,
softmax_results
,
dropout_mask
,
heads
,
use_mask
output_grads
,
?
static_cast
<
const
uint8_t
*>
(
padding_mask
.
data_ptr
())
softmax_results
,
:
nullptr
,
dropout_mask
,
dropout_prob
);
use_mask
?
static_cast
<
const
uint8_t
*>
(
padding_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
}
}
// end namespace mask_softmax_dropout
}
// end namespace mask_softmax_dropout
...
@@ -87,7 +72,8 @@ torch::Tensor bwd(
...
@@ -87,7 +72,8 @@ torch::Tensor bwd(
}
// end namespace multihead_attn
}
// end namespace multihead_attn
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
mask_softmax_dropout
::
fwd
,
"Self Multihead Attention masked softmax dropout -- Forward."
);
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
mask_softmax_dropout
::
fwd
,
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
mask_softmax_dropout
::
bwd
,
"Self Multihead Attention masked softmax dropout -- Backward."
);
"Self Multihead Attention masked softmax dropout -- Forward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
mask_softmax_dropout
::
bwd
,
"Self Multihead Attention masked softmax dropout -- Backward."
);
}
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment