OpenDAS / apex · Commits · c2b62b7f

Commit c2b62b7f, authored Mar 13, 2025 by JR_ZZU 🌴

    delete origin files

Parent: 2a4864d5
Changes: 164
Showing 20 changed files with 0 additions and 10749 deletions (+0, -10749)
  apex/contrib/csrc/layer_norm/ln_api.cpp                                          +0  -246
  apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh                                  +0  -315
  apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu                          +0  -250
  apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu                               +0  -235
  apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh                                  +0  -114
  apex/contrib/csrc/layer_norm/ln_kernel_traits.h                                  +0  -159
  apex/contrib/csrc/layer_norm/ln_utils.cuh                                        +0  -793
  apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu         +0  -113
  apex/contrib/csrc/multihead_attn/dropout.cuh                                     +0  -272
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   +0  -611
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          +0  -690
  apex/contrib/csrc/multihead_attn/layer_norm.cuh                                  +0  -649
  apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu                  +0  -124
  apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp                     +0  -836
  apex/contrib/csrc/multihead_attn/philox.cuh                                      +0  -96
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +0  -504
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +0  -504
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     +0  -509
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            +0  -580
  apex/contrib/csrc/multihead_attn/softmax.cuh                                     +0  -3149
Note: too many changes to show; to preserve performance only 164 of 164+ files are displayed.
apex/contrib/csrc/layer_norm/ln_api.cpp (deleted, 100644 → 0)
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ln.h"

/*
Supported Type combinations:

input   compute   weights   output
=======================================
fp32     fp32      fp32      fp32
fp16     fp32      fp16      fp16
bf16     fp32      bf16      bf16
fp32     fp32      fp16      fp16
fp32     fp32      bf16      bf16

Remarks:
Output type = Weight type
Compute always in FP32
*/

namespace layer_norm {

// Create registries and provide runtime versions of config hash functions.
// FwdRegistry FWD_FUNCS;
// BwdRegistry BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

uint32_t get_type_id(torch::Dtype dtype){
    if( dtype == torch::kFloat16 ) {
        return TypeId<fp16>::Value;
    } else if( dtype == torch::kBFloat16 ) {
        return TypeId<bf16>::Value;
    } else if( dtype == torch::kFloat32 ) {
        return TypeId<fp32>::Value;
    } else {
        TORCH_CHECK(false, "Type not supported: ", dtype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
    using namespace layer_norm;
    uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);
    uint64_t launcher_key = (type_key << 32) | hidden_size;
    return launcher_key;
}

}  // namespace layer_norm

////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
    if( iter != layer_norm::FWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));
    if( iter != layer_norm::BWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> ln_fwd(const at::Tensor &x,      // BxSxhidden_size
                               const at::Tensor &gamma,  // hidden_size
                               const at::Tensor &beta,   // hidden_size
                               const float epsilon
) {
    auto itype = x.scalar_type();
    auto wtype = gamma.scalar_type();
    auto otype = wtype;
    auto ctype = torch::kFloat32;

    TORCH_CHECK(beta.scalar_type() == wtype);

    TORCH_CHECK(x.is_cuda())
    TORCH_CHECK(gamma.is_cuda())
    TORCH_CHECK(beta.is_cuda())

    TORCH_CHECK(x.is_contiguous());
    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);

    const int rows = sizes[0];
    const int cols = sizes[1];
    auto hidden_size = gamma.numel();

    TORCH_CHECK(gamma.sizes() == beta.sizes());
    TORCH_CHECK(hidden_size == cols);

    TORCH_CHECK(epsilon >= 0.f);

    auto opts = x.options();

    auto z = torch::empty(sizes, opts.dtype(otype));

    auto mu = torch::empty({rows}, opts.dtype(ctype));
    auto rsigma = torch::empty({rows}, opts.dtype(ctype));

    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;

    launch_params.props = at::cuda::getCurrentDeviceProperties();
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();

    // Request the kernel launcher.
    auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);

    // Query the kernel-specific launch parameters.
    launcher(launch_params, true);

    at::Tensor workspace, barrier;

    // Set the kernel runtime parameters.
    layer_norm::FwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.beta = beta.data_ptr();
    params.z = z.data_ptr();
    params.epsilon = epsilon;

    if( launch_params.barrier_size > 0 ) {
        auto options = x.options();
        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    // Launch the kernel.
    launcher(launch_params, false);

    return { z, mu, rsigma };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
                               const at::Tensor &x,      // BxSxhidden_size
                               const at::Tensor &mu,     // BxS, FP32!
                               const at::Tensor &rsigma, // BxS, FP32!
                               const at::Tensor &gamma   // hidden_size
) {
    auto itype = x.scalar_type();
    auto wtype = gamma.scalar_type();
    auto otype = wtype;
    auto ctype = torch::kFloat32;

    TORCH_CHECK(dz.dtype() == otype);
    TORCH_CHECK(mu.dtype() == ctype);
    TORCH_CHECK(rsigma.dtype() == ctype);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(dz.is_cuda());
    TORCH_CHECK(mu.is_cuda());
    TORCH_CHECK(rsigma.is_cuda());
    TORCH_CHECK(gamma.is_cuda());

    TORCH_CHECK(x.is_contiguous());
    TORCH_CHECK(dz.is_contiguous());

    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);
    TORCH_CHECK(dz.sizes() == sizes);
    auto rows = sizes[0];
    auto cols = sizes[1];

    auto hidden_size = gamma.numel();

    TORCH_CHECK(mu.numel() == rows);
    TORCH_CHECK(mu.sizes() == rsigma.sizes());

    TORCH_CHECK(gamma.numel() == cols);

    auto options = x.options();

    auto dx = torch::empty_like(x);
    auto dgamma = torch::empty_like(gamma);
    auto dbeta = torch::empty_like(gamma);

    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();

    auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);

    launcher(launch_params, true);

    auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype));
    auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype));
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.dz = dz.data_ptr();
    params.dx = dx.data_ptr();
    params.dbeta = dbeta.data_ptr();
    params.dgamma = dgamma.data_ptr();
    params.dbeta_part = dbeta_part.data_ptr();
    params.dgamma_part = dgamma_part.data_ptr();

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    launcher(launch_params, false);

    return { dx, dgamma, dbeta, dgamma_part, dbeta_part };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "CUDA LayerNorm";
    m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel");
    m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}
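For orientation, the dispatch key that get_key builds packs four 2-bit type ids and the hidden size into one 64-bit value, which get_fwd_launcher/get_bwd_launcher then use to look up the registered kernel. A minimal standalone sketch of that packing follows; the numeric ids are assumptions for illustration only, since the real TypeId<fp16/bf16/fp32>::Value constants live in ln.h, which is not part of this commit view.

#include <cstdint>
#include <cstdio>

// Hypothetical type ids; the real values come from TypeId<...>::Value in ln.h.
enum : uint64_t { ID_FP16 = 0, ID_BF16 = 1, ID_FP32 = 2 };

int main() {
    // fp16 weights/input/output with fp32 compute and hidden_size = 1024,
    // which matches one of the registered configurations.
    uint64_t type_key = ID_FP16 | (ID_FP16 << 2) | (ID_FP16 << 4) | (ID_FP32 << 6);
    uint64_t launcher_key = (type_key << 32) | uint64_t(1024);
    std::printf("type_key=0x%llx launcher_key=0x%llx\n",
                (unsigned long long)type_key, (unsigned long long)launcher_key);
    return 0;
}

Because the low 32 bits carry the hidden size, two configurations can only collide in FWD_FUNCS/BWD_FUNCS if both their type combination and their hidden size match.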
apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh (deleted, 100644 → 0)
#pragma once

namespace layer_norm {

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { COLS = Ktraits::COLS };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using compute_t = typename Ktraits::compute_t;
    using index_t = typename Ktraits::index_t;
    using Ivec = typename Ktraits::Ivec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Reducer = typename Ktraits::Reducer;
    using reduce_t = typename Reducer::Type;

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / Ktraits::WARPS_N;
    const index_t warp_n = warp % Ktraits::WARPS_N;
    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;

    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);

    Cvec dzy_sum[LDGS];
    Cvec dz_sum[LDGS];

    memset(dzy_sum, 0, sizeof(dzy_sum));
    memset(dz_sum, 0, sizeof(dz_sum));

    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
    char * smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;

    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);

    Sum<reduce_t> sum;

    constexpr float rn = 1.f / float(COLS);
    Wvec gamma[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        gamma[it].load_from(params.gamma, idx);
        idx += Ktraits::VEC_COLS_PER_LDG;
    }
    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
    // last blocks with syncthreads!
    // grid stride over rows
    #pragma unroll 1
    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
        Ivec x[LDGS];
        Ovec dz[LDGS];
        index_t idx = row * Ktraits::VEC_COLS + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dz[it].load_from(params.dz, idx);
            x[it].load_from(params.x, idx);
            idx += Ktraits::VEC_COLS_PER_LDG;
        }

        compute_t dy[LDGS * NUM_ELTS];
        compute_t y[LDGS * NUM_ELTS];

        compute_t mdy_local = 0.f;
        compute_t mdyy_local = 0.f;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                compute_t x_tmp = x[it].data.elt[jt];
                compute_t y_tmp = rs_r * (x_tmp - mu_r);
                compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
                dy_tmp *= compute_t(dz[it].data.elt[jt]);
                compute_t dz_tmp = dz[it].data.elt[jt];

                mdy_local += dy_tmp;
                mdyy_local += dy_tmp * y_tmp;

                dy[it * NUM_ELTS + jt] = dy_tmp;
                y[it * NUM_ELTS + jt] = y_tmp;

                dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
                dz_sum[it].data.elt[jt] += dz_tmp;
            }
        }

        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;

        Ivec dx[LDGS];
        idx = row * Ktraits::VEC_COLS + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                compute_t y_tmp = y[it * NUM_ELTS + jt];
                compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
                dx[it].data.elt[jt] = dx_tmp;
            }
            dx[it].store_to(params.dx, idx);
            idx += Ktraits::VEC_COLS_PER_LDG;
        }

    }  // end: grid stride loop

    if( WARPS_M == 1 ) {
        idx = r * Ktraits::VEC_COLS + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dz_sum[it].store_to(params.dbeta_part, idx);
            dzy_sum[it].store_to(params.dgamma_part, idx);
            idx += Ktraits::VEC_COLS_PER_LDG;
        }
    } else {
        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
        // Finalize reduction of part dgamma and dbeta for this CTA
        // by reducing over the rows held across the WARPS_M warps
        // Assumption: blockSize divides hidden size.
        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");

        idx = warp_m * Ktraits::VEC_COLS + tid_r;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dz_sum[it].store_to(smem_wgrad, idx);
            idx += THREADS_PER_ROW;
        }
        __syncthreads();
        compute_t cta_dz_sum[NUM_RES];
        memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
            for( int jt = 0; jt < NUM_RES; jt++ ) {
                cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
            }
        }
        __syncthreads();

        idx = warp_m * Ktraits::VEC_COLS + tid_r;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            dzy_sum[it].store_to(smem_wgrad, idx);
            idx += THREADS_PER_ROW;
        }
        __syncthreads();
        compute_t cta_dzy_sum[NUM_RES];
        memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
        for( int it = 0; it < ROWS_PER_CTA; it++ ) {
            for( int jt = 0; jt < NUM_RES; jt++ ) {
                cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
            }
        }

        compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
        for( int jt = 0; jt < NUM_RES; jt++ ) {
            *dgamma_part = cta_dzy_sum[jt];
            dgamma_part += Ktraits::THREADS_PER_CTA;
        }

        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
        for( int jt = 0; jt < NUM_RES; jt++ ) {
            *dbeta_part = cta_dz_sum[jt];
            dbeta_part += Ktraits::THREADS_PER_CTA;
        }
    }
}

template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
    using compute_t = typename Kernel_traits::compute_t;
    using weight_t = typename Kernel_traits::weight_t;
    using index_t = typename Kernel_traits::index_t;
    using Reducer = typename Kernel_traits::Reducer;
    using reduce_t = typename Reducer::Type;

    Sum<reduce_t> sum;
    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };

    __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];

    constexpr uint32_t bidm = 0;
    const uint32_t bidn = blockIdx.x;
    const uint32_t tidx = threadIdx.x;
    const uint32_t warp = tidx / THREADS_PER_WARP;
    const uint32_t lane = tidx % THREADS_PER_WARP;

    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);

    const uint32_t c = bidn * THREADS_PER_WARP + lane;
    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
        // Each thread sums over NUM_ELT columns.
        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
        memset(&dgamma_local, 0, sizeof(dgamma_local));
        memset(&dbeta_local, 0, sizeof(dbeta_local));
        for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
            index_t idx = row * Kernel_traits::COLS + col;

            Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
            dbeta_part.load_from(params.dbeta_part, idx);
            dgamma_part.load_from(params.dgamma_part, idx);
            #pragma unroll
            for( int it = 0; it < NUM_ELT; it++ ) {
                dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
            }
        }

        void * smem_gamma = smem_;
        void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];

        const int write_row = warp;
        const int write_col = lane ^ write_row;
        const int write_idx = write_row * THREADS_PER_WARP + write_col;

        dgamma_local.store_to(smem_gamma, write_idx);
        dbeta_local.store_to(smem_beta, write_idx);

        __syncthreads();

        // It would be probably safe to reuse the first row of smem_beta and smem_gamma
        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];

        // More than one iter iff ROWS_PER_CTA < 32.
        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
            const int read_row = lane;
            const int read_col = w ^ read_row;
            const int read_idx = read_row * THREADS_PER_WARP + read_col;

            memset(&dbeta_local, 0, sizeof(dbeta_local));
            memset(&dgamma_local, 0, sizeof(dgamma_local));

            // Load beta and gamma transposed
            if( read_row < Kernel_traits::ROWS_PER_CTA ){
                dbeta_local.load_from(smem_beta, read_idx);
                dgamma_local.load_from(smem_gamma, read_idx);
            }

            // Call reducer on the loaded value(s) and convert.
            #pragma unroll
            for( int it = 0; it < NUM_ELT; it++ ) {
                compute_t b_i = dbeta_local.data.elt[it];
                compute_t g_i = dgamma_local.data.elt[it];
                b_i = reducer.allreduce(b_i, sum);
                g_i = reducer.allreduce(g_i, sum);

                dgamma_local.data.elt[it] = g_i;
                dbeta_local.data.elt[it] = b_i;
            }

            // Leader stores the result at the current column.
            if( lane == 0 ){
                dgamma_local.store_to(smem_gamma_out, w);
                dbeta_local.store_to(smem_beta_out, w);
            }
        }

        // All writes done.
        __syncthreads();

        // Pack and store: 2-wide stores with half the threads.
        if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
            using src_t = typename TypeToVec2<compute_t>::Type;
            using dst_t = typename TypeToVec2<weight_t>::Type;
            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;

            dgamma_vec2.load_from(smem_gamma_out, lane);
            dbeta_vec2.load_from(smem_beta_out, lane);
            #pragma unroll
            for( int it = 0; it < NUM_ELT; it++ ) {
                dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);
                dbeta_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dbeta_vec2.data.elt[it]);
            }
            dgamma_out2.store_to(params.dgamma, col_out);
            dbeta_out2.store_to(params.dbeta, col_out);
        }
    }
}
}  // namespace layer_norm
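Written out as math, the gradients that ln_bwd_kernel and ln_bwd_finalize_kernel compute are the standard layer-norm backward. With y_i = r_s (x_i - mu), dy_i = gamma_i * dz_i and N = COLS, the code above realizes (a reference formulation, not part of the original file):

\[
\frac{\partial L}{\partial x_i} = r_s \Big( dy_i - \frac{1}{N}\sum_{j} dy_j - \frac{y_i}{N}\sum_{j} dy_j\, y_j \Big), \qquad
\frac{\partial L}{\partial \gamma_i} = \sum_{\text{rows}} dz_i\, y_i, \qquad
\frac{\partial L}{\partial \beta_i} = \sum_{\text{rows}} dz_i .
\]

ln_bwd_kernel produces dx directly plus per-CTA partial sums dgamma_part/dbeta_part of shape (ctas_per_col, COLS); ln_bwd_finalize_kernel then reduces those partials over the row dimension into dgamma and dbeta.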
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu (deleted, 100644 → 0)
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
using
namespace
layer_norm
;
BwdRegistry
layer_norm
::
BWD_FUNCS
;
template
<
typename
weight_t
,
typename
input_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG_MAIN
,
int
BYTES_PER_LDG_FINAL
>
void
launch_
(
LaunchParams
<
BwdParams
>
&
launch_params
,
const
bool
configure_params
){
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG_MAIN
>
;
auto
kernel
=
&
ln_bwd_kernel
<
Kernel_traits
>
;
if
(
configure_params
)
{
int
ctas_per_sm
;
cudaError_t
status_
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
);
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
WARPS_M
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
reduce_t
)
*
2
;
}
return
;
}
if
(
Kernel_traits
::
SMEM_BYTES
>=
48
*
1024
)
{
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA
(
hipFuncSetAttribute
((
const
void
*
)
kernel
,
hipFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
#else
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
#endif
}
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
,
stream
>>>
(
launch_params
.
params
);
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES
,
stream
);
#else
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES
,
stream
);
#endif
}
using
Kernel_traits_f
=
layer_norm
::
Kernel_traits_finalize
<
HIDDEN_SIZE
,
weight_t
,
input_t
,
output_t
,
compute_t
,
index_t
,
32
*
32
,
// THREADS_PER_CTA
BYTES_PER_LDG_FINAL
>
;
auto
kernel_f
=
&
layer_norm
::
ln_bwd_finalize_kernel
<
Kernel_traits_f
>
;
kernel_f
<<<
Kernel_traits_f
::
CTAS
,
Kernel_traits_f
::
THREADS_PER_CTA
,
0
,
stream
>>>
(
launch_params
.
params
);
}
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER
(
768
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
768
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER
(
1024
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
1024
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER
(
1536
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
1536
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
2048
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
2048
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
2304
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
2304
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
2304
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER
(
3072
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
3072
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
3840
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
3840
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
3840
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER
(
4096
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
4096
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
5120
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
5120
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
6144
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
6144
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
10240
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
10240
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
12288
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12288
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
12800
,
fp32
,
fp32
,
fp32
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
fp16
,
fp16
,
fp16
,
fp32
,
5
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
12800
,
fp16
,
fp32
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
14336
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
14336
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
14336
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
// REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER
(
15360
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
15360
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER
(
16384
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
16384
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
18432
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
18432
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
20480
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
20480
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
24576
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
24576
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER
(
25600
,
fp32
,
fp32
,
fp32
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
fp16
,
fp16
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
25600
,
fp16
,
fp32
,
fp16
,
fp32
,
5
,
1
,
4
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER
(
30720
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
4
,
4
);
REGISTER_BWD_LAUNCHER
(
30720
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
8
,
4
);
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER
(
32768
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
32768
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER
(
40960
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
40960
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER
(
49152
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
49152
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER
(
65536
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
65536
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
8
,
16
,
4
);
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
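As a worked example of the configure_params branch above, the grid height and the multi-CTA scratch sizes fall out of simple arithmetic. The sketch below mirrors those formulas with illustrative numbers (a hypothetical 108-SM device, 2 resident CTAs per SM reported by cudaOccupancyMaxActiveBlocksPerMultiprocessor, the hidden_size 8192 / fp16 configuration with CTAS_PER_ROW = 2 and WARPS_M = 1, and reduce_t assumed to be float2); it is not a measurement.

#include <cstdio>

int main() {
    int sm_count = 108, ctas_per_sm = 2;      // illustrative occupancy result
    int ctas_per_row = 2, warps_m = 1;        // hidden_size 8192, fp16 config
    int reduce_t_bytes = 8;                   // assuming reduce_t == float2

    int ctas_per_col = sm_count * ctas_per_sm / ctas_per_row;            // 108 rows in flight
    int barrier_size = 2 * ctas_per_col;                                 // 216 int32 flags
    long workspace_bytes = 1L * ctas_per_col * warps_m * ctas_per_row
                           * reduce_t_bytes * 2;                         // 3456 bytes
    std::printf("ctas_per_col=%d barrier_size=%d workspace_bytes=%ld\n",
                ctas_per_col, barrier_size, workspace_bytes);
    return 0;
}

These are exactly the sizes that ln_fwd/ln_bwd in ln_api.cpp allocate as the barrier and workspace tensors whenever barrier_size > 0, i.e. whenever a row is split across several cooperating CTAs.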
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu (deleted, 100644 → 0)
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
using
namespace
layer_norm
;
FwdRegistry
layer_norm
::
FWD_FUNCS
;
template
<
typename
weight_t
,
typename
input_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG
>
void
launch_
(
LaunchParams
<
FwdParams
>
&
launch_params
,
const
bool
configure_params
){
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
>
;
auto
kernel
=
&
ln_fwd_kernel
<
Kernel_traits
>
;
if
(
configure_params
)
{
int
ctas_per_sm
;
cudaError_t
status_
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
);
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
WARPS_M
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
Stats
::
stats_t
)
*
2
;
}
return
;
}
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA
(
hipFuncSetAttribute
((
const
void
*
)
kernel
,
hipFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
#else
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
#endif
}
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
>>>
(
launch_params
.
params
);
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
);
#else
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
);
#endif
}
}
REGISTER_FWD_LAUNCHER
(
768
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
768
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
// REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER
(
1024
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
1024
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
// REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER
(
1536
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
1536
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
// REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER
(
2048
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
2048
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
// REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER
(
2304
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
2304
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_FWD_LAUNCHER
(
2304
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER
(
3072
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
3072
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
3840
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
3840
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
3840
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
4
);
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER
(
4096
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
4096
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
5120
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
5120
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
6144
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
6144
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
10240
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
10240
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
10240
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
12288
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
12288
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
12288
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
12800
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
12800
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
12800
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
4
);
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER
(
14336
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
14336
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
14336
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER
(
15360
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
8
);
REGISTER_FWD_LAUNCHER
(
15360
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
8
);
REGISTER_FWD_LAUNCHER
(
15360
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
8
);
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER
(
16384
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
16384
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
18432
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
18432
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
18432
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
20480
,
fp32
,
fp32
,
fp32
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
20480
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
20480
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
24576
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
24576
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
24576
,
fp16
,
fp32
,
fp16
,
fp32
,
2
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
25600
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
25600
,
fp16
,
fp16
,
fp16
,
fp32
,
2
,
1
,
4
,
8
);
REGISTER_FWD_LAUNCHER
(
25600
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER
(
30720
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
REGISTER_FWD_LAUNCHER
(
30720
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
4
);
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER
(
32768
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
32768
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
40960
,
fp32
,
fp32
,
fp32
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
40960
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
49152
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
fp16
,
fp16
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
49152
,
fp16
,
fp32
,
fp16
,
fp32
,
4
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER
(
65536
,
fp32
,
fp32
,
fp32
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
fp16
,
fp16
,
fp16
,
fp32
,
8
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
65536
,
fp16
,
fp32
,
fp16
,
fp32
,
8
,
1
,
4
,
16
);
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh (deleted, 100644 → 0)
#pragma once

#if defined(__HIP_PLATFORM_HCC__)
#include "ln_utils.cuh"
#else
#include "ln.h"
#endif

namespace layer_norm {

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::NUM_ELTS };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using output_t = typename Ktraits::output_t;
    using index_t = typename Ktraits::index_t;
    using compute_t = typename Ktraits::compute_t;
    using Ivec = typename Ktraits::Ivec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;

    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

    Wvec gamma[LDGS];
    Wvec beta[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        gamma[it].load_from(params.gamma, idx);
        beta[it].load_from(params.beta, idx);
        idx += VEC_COLS_PER_LDG;
    }

    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        Ivec x[LDGS];
        index_t idx = row * Ktraits::VEC_COLS + c;
        compute_t xf[LDGS * NUM_ELTS];
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            x[it].load_from(params.x, idx);
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                compute_t x_ij = compute_t(x[it].data.elt[jt]);
                xf[it * NUM_ELTS + jt] = x_ij;
            }
            idx += VEC_COLS_PER_LDG;
        }

        stats_t s = stats.compute(xf, rn);

        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(rn * m2 + params.epsilon);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;
        }

        Ovec z[LDGS];
        idx = row * Ktraits::VEC_COLS + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
                output_t g_ij = gamma[it].data.elt[jt];
                output_t b_ij = beta[it].data.elt[jt];
                z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
            }
            z[it].store_to(params.z, idx);
            idx += VEC_COLS_PER_LDG;
        }

    }
}

}  // namespace layer_norm
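For reference, the per-row math realized by ln_fwd_kernel, assuming Stats::compute (defined in ln.h, which is not shown in this commit) returns the mean and the sum of squared deviations m2, so that rn * m2 is the variance:

\[
\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
r_s = \frac{1}{\sqrt{\tfrac{m_2}{N} + \epsilon}}, \qquad
z_i = \gamma_i\, r_s\,(x_i - \mu) + \beta_i, \qquad N = \text{COLS}.
\]

One thread of each row's leading warp (bidn == 0, warp_n == 0, lane == 0) also writes mu and r_s to params.mu and params.rs, which is exactly what the backward kernel reloads as mu_r and rs_r.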
apex/contrib/csrc/layer_norm/ln_kernel_traits.h (deleted, 100644 → 0)
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace
layer_norm
{
template
<
uint32_t
HIDDEN_SIZE_
,
typename
weight_t_
,
typename
input_t_
,
typename
output_t_
,
typename
compute_t_
,
typename
index_t_
,
uint32_t
THREADS_PER_CTA_
>
struct
Kernel_traits_base
{
using
weight_t
=
weight_t_
;
using
input_t
=
input_t_
;
using
output_t
=
output_t_
;
using
compute_t
=
compute_t_
;
using
index_t
=
index_t_
;
enum
{
HIDDEN_SIZE
=
HIDDEN_SIZE_
};
enum
{
THREADS_PER_CTA
=
THREADS_PER_CTA_
};
enum
{
THREADS_PER_WARP
=
32
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
uint32_t
HIDDEN_SIZE_
,
typename
weight_t_
,
typename
input_t_
,
typename
output_t_
,
typename
compute_t_
,
typename
index_t_
,
uint32_t
THREADS_PER_CTA_
,
uint32_t
BYTES_PER_LDG_
,
typename
Base
=
Kernel_traits_base
<
HIDDEN_SIZE_
,
weight_t_
,
input_t_
,
output_t_
,
compute_t_
,
index_t_
,
THREADS_PER_CTA_
>
>
struct
Kernel_traits_finalize
:
public
Base
{
enum
{
ROWS_PER_CTA
=
Base
::
THREADS_PER_CTA
/
Base
::
THREADS_PER_WARP
};
static_assert
((
int
)
ROWS_PER_CTA
<=
(
int
)
Base
::
THREADS_PER_WARP
);
// Bytes per global load from the input.
enum
{
BYTES_PER_LDG
=
BYTES_PER_LDG_
};
// Number of elements fetched by a global load.
enum
{
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
compute_t_
)
};
// Bytes per global store of the weights.
enum
{
BYTES_PER_STG
=
ELTS_PER_LDG
*
sizeof
(
weight_t_
)
};
static_assert
(
sizeof
(
BYTES_PER_LDG
)
==
4
,
"Conflict-free smem transpose only implemented for 4B compute type!"
);
static_assert
(
Base
::
THREADS_PER_CTA
==
ROWS_PER_CTA
*
Base
::
THREADS_PER_WARP
,
"We assume one warp per row!"
);
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum
{
COLS
=
HIDDEN_SIZE_
*
sizeof
(
compute_t_
)
/
BYTES_PER_LDG
};
static_assert
(
COLS
*
BYTES_PER_LDG
==
HIDDEN_SIZE_
*
sizeof
(
compute_t_
));
// Shared memory size to transpose the CTA result.
enum
{
SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
    // Shared memory size to coalesce the CTA result.
    enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
    // Shared memory requirement per CTA.
    enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
    // The type of the reducer.
    using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
    // Condition for the whole CTA to participate in syncthreads.
    static_assert(COLS % Base::THREADS_PER_WARP == 0);
    enum { CTAS = COLS / Base::THREADS_PER_WARP };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename weight_t_, typename input_t_, typename output_t_, typename compute_t_, typename index_t_,
         uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_, uint32_t WARPS_N_,
         uint32_t BYTES_PER_LDG_ = 16,
         typename Base = Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,
                                            WARPS_M_ * WARPS_N_ * THREADS_PER_WARP>>
struct Kernel_traits : public Base {

    using input_t = typename Base::input_t;
    using weight_t = typename Base::weight_t;
    using compute_t = typename Base::compute_t;
    using output_t = typename Base::output_t;
    using index_t = typename Base::index_t;

    enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
    enum { WARPS_M = WARPS_M_ };
    enum { WARPS_N = WARPS_N_ };
    enum { COLS = HIDDEN_SIZE_ };
    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };

    enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
    enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
    enum { ROWS_PER_CTA = WARPS_M };

    enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
    enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
    // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
    enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
    static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);

    using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
    using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;

    enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
    enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };

    using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
    using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
    using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
    using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };

    // Assume that each thread can handle the same number of elements in the output and weights as in the input.
    static_assert(sizeof(input_t) >= sizeof(output_t));
    static_assert(sizeof(input_t) >= sizeof(weight_t));
    // The number of columns fetched per load from input: one per thread.
    enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
    // The total number of vectorized loads/stores per hidden vector.
    enum { VEC_COLS = COLS / ELTS_PER_LDG };
    // The number of loads per thread for the input.
    enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };

    static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
    //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

    using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
    enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm
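As an aside (not part of the deleted header): the derived constants above are just integer arithmetic, and it can help to work one configuration through by hand. The standalone sketch below mirrors that arithmetic for an illustrative instantiation — fp16 input with HIDDEN_SIZE = 1024, CTAS_PER_ROW = 1, WARPS_M = 4, WARPS_N = 1 and 16-byte loads; the names and numbers are assumptions, not values taken from the repository.

// Illustrative sketch only: mirrors the Kernel_traits arithmetic for one assumed configuration.
#include <cstdint>

constexpr uint32_t kThreadsPerWarp = 32;
constexpr uint32_t kHiddenSize     = 1024;                      // COLS
constexpr uint32_t kBytesPerLdg    = 16;                        // one uint4-sized load
constexpr uint32_t kNumElts        = kBytesPerLdg / 2;          // sizeof(half) == 2  -> 8 elements per load
constexpr uint32_t kWarpsM = 4, kWarpsN = 1, kCtasPerRow = 1;

constexpr uint32_t kThreadsPerRow = kWarpsN * kThreadsPerWarp;   // 32
constexpr uint32_t kThreadsPerCta = kWarpsM * kThreadsPerRow;    // 128
constexpr uint32_t kVecCols       = kHiddenSize / kNumElts;      // 128 vectorized columns per row
constexpr uint32_t kVecColsPerLdg = kCtasPerRow * kThreadsPerRow;// 32 columns covered per load
constexpr uint32_t kLdgs          = kVecCols / kVecColsPerLdg;   // 4 loads per thread per row

static_assert(kLdgs * kVecColsPerLdg == kVecCols, "row must be covered exactly");
static_assert(kHiddenSize % kThreadsPerWarp == 0, "whole CTA participates in syncthreads");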
apex/contrib/csrc/layer_norm/ln_utils.cuh  deleted 100644 → 0  View file @ 2a4864d5
#pragma once

#include <cassert>

#if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif

#include "ln.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

inline void check_cuda_(cudaError_t status, const char *file, int line) {
    if( status != cudaSuccess ) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
        exit(status);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(ans) \
    { check_cuda_((ans), __FILE__, __LINE__); }

////////////////////////////////////////////////////////////////////////////////////////////////////

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
    void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,         \
                                                                      const bool configure_params) {                  \
        launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(    \
            launch_params, configure_params);                                                                         \
    }                                                                                                                 \
    static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
        ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_BWD_LAUNCHER(                                                                                        \
    HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)   \
    void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,         \
                                                                      const bool configure_params) {                  \
        launch_<WTYPE,                                                                                                \
                ITYPE,                                                                                                \
                OTYPE,                                                                                                \
                CTYPE,                                                                                                \
                uint32_t,                                                                                             \
                HIDDEN_SIZE,                                                                                          \
                CTAS_PER_ROW,                                                                                         \
                WARPS_M,                                                                                              \
                WARPS_N,                                                                                              \
                BYTES_PER_LDG,                                                                                        \
                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                             \
    }                                                                                                                 \
    static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
        ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 operator+(const float2 &a, const float2 &b) {
    return {a.x + b.x, a.y + b.y};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void operator+=(float2 &a, const float2 &b) {
    a.x += b.x;
    a.y += b.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct Sum {
    inline __device__ Sum() {}
    inline __device__ T operator()(const T &a, const T &b) { return a + b; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#if defined(__HIP_PLATFORM_HCC__)
    return __shfl_xor(x, idx);
#else
    return __shfl_xor_sync(uint32_t(-1), x, idx);
#endif
}

template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
    return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
}

template<typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#if defined(__HIP_PLATFORM_HCC__)
    return __shfl_down(x, idx);
#else
    return __shfl_down_sync(uint32_t(-1), x, idx);
#endif
}

template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
    return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint16 {
    uint4 u;
    uint4 v;
    uint4 s;
    uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint8 {
    uint4 u;
    uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES>
struct BytesToType {};

template<> struct BytesToType<64> { using Type = uint16;   static_assert(sizeof(Type) == 64); };
template<> struct BytesToType<32> { using Type = uint8;    static_assert(sizeof(Type) == 32); };
template<> struct BytesToType<16> { using Type = uint4;    static_assert(sizeof(Type) == 16); };
template<> struct BytesToType<8>  { using Type = uint64_t; static_assert(sizeof(Type) == 8); };
template<> struct BytesToType<4>  { using Type = uint32_t; static_assert(sizeof(Type) == 4); };
template<> struct BytesToType<2>  { using Type = uint16_t; static_assert(sizeof(Type) == 2); };
template<> struct BytesToType<1>  { using Type = uint8_t;  static_assert(sizeof(Type) == 1); };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct TypeToVec2 {};

template<> struct TypeToVec2<float> { using Type = float2; };
template<> struct TypeToVec2<half>  { using Type = half2; };

#if 0
template<>
struct TypeToVec2<nv_bfloat16> {
    using Type = nv_bfloat162;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int INDEX>
struct Get {
    template<typename T, typename R>
    static inline __device__ R of(const T &vec);
};

template<> template<typename T, typename R> inline __device__ R Get<0>::of(const T &vec) { return vec.x; }
template<> template<typename T, typename R> inline __device__ R Get<1>::of(const T &vec) { return vec.y; }
template<> template<typename T, typename R> inline __device__ R Get<2>::of(const T &vec) { return vec.z; }
template<> template<typename T, typename R> inline __device__ R Get<3>::of(const T &vec) { return vec.w; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Src, typename Dst>
struct Converter {
    static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};

template<>
struct Converter<float2, half2> {
    static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};

#if defined(__HIP_PLATFORM_HCC__)
template<>
struct Converter<float, half> {
    static inline __device__ half convert(const float &x) { return __float2half(x); }
};

template<>
struct Converter<half, float> {
    static inline __device__ float convert(const half &x) { return __half2float(x); }
};

template<>
struct Converter<float, hip_bfloat16> {
    static inline __device__ hip_bfloat16 convert(const float &x) { return hip_bfloat16::round_to_bfloat16(x); }
};

template<>
struct Converter<hip_bfloat16, float> {
    static inline __device__ float convert(const hip_bfloat16 &x) { return float(x); }
};
#endif

#if 0
template<>
struct Converter<float2, nv_bfloat162>{
    static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
        return __float22bfloat162_rn(x);
#else
        union {
            nv_bfloat162 raw;
            nv_bfloat16 x;
            nv_bfloat16 y;
        } tmp;
        tmp.x = __float2bfloat16_rn(x.x);
        tmp.y = __float2bfloat16_rn(x.y);
        return tmp.raw;
#endif
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct Zeros {
    static inline __device__ T get() { return T(0.f); }
};

template<>
struct Zeros<float2> {
    static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {

    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    using Vec_type = typename BytesToType<BYTES>::Type;

    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    template<typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
        #pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            other.data.elt[it] = S(this->data.elt[it]);
        }
    }

    template<typename Op>
    inline __device__ void assign(const Op &op) {
        #pragma unroll
        for( int it = 0; it < NUM_ELT; it++ ) {
            this->data.elt[it] = op(it);
        }
    }

    inline __device__ void load_from(const void *base_ptr, const size_t idx) {
        this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
    }

    inline __device__ void store_to(void *base_ptr, const size_t idx) {
        static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
    }
};
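As an illustrative aside (not part of the deleted header): the Vec alias above is what turns per-element register work into one wide load and one wide store. The toy kernel below sketches that pattern under stated assumptions — the kernel name, buffer names, and launch shape are invented for the example, and col indexes 8-element chunks of a row.

// Minimal sketch, assuming a buffer whose length is a multiple of 8 halves and 16-byte aligned.
__global__ void scale_row_f16x8(const half *in, half *out, float alpha, int vec_cols) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;   // index of an 8-element chunk
    if( col >= vec_cols ) return;

    layer_norm::Vec<half, 8> x;
    x.load_from(in, col);                                     // one 16-byte (uint4) load

    layer_norm::Vec<float, 8> xf;
    x.to(xf);                                                 // widen to fp32 in registers
    xf.assign([&](int it) { return xf.data.elt[it] * alpha; });

    layer_norm::Vec<half, 8> y;
    xf.to(y);                                                 // narrow back to fp16
    y.store_to(out, col);                                     // one 16-byte store
}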
////////////////////////////////////////////////////////////////////////////////////////////////////

template<uint32_t CTAS_PER_ROW>
struct InterCTASync {

    template<typename Params>
    inline __device__ InterCTASync(Params &params, uint32_t bidm, uint32_t bidn)
        : phase_counter_(0)
        , b0_(params.barrier + bidm)                        // The barrier for this group of CTAs.
        , b1_(params.barrier + bidm + params.ctas_per_col)  // The barrier for this group of CTAs.
    {
        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
    }

    inline __device__ void spin_wait_(int *barrier, int step, int expected) {
#if defined(__HIP_PLATFORM_HCC__)
        atomicAdd(barrier, step);
        for( int found = -1; found != expected; ) {
            // asm volatile("global_load_dword %0, %1, off;" : "=v"(found) : "v"(barrier));
            found = atomicCAS(barrier, expected, expected);
        }
#else
        asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
        for( int found = -1; found != expected; ) {
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
        }
#endif
    }

    inline __device__ void sync() {
        // ALL THREADS MUST ENTER!

        // We switch barrier every iteration.
        int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
        // We decrement every other iteration.
        bool dec = phase_counter_ & 0x2;
        int step = dec ? -1 : 1;
        int expected = dec ? 0 : CTAS_PER_ROW;
        // There are only 4 phases: up/down for b0/b1.
        phase_counter_ = (phase_counter_ + 1) & 0x3;

        if( threadIdx.x == 0 ) {
            spin_wait_(barrier, step, expected);
        }
        // CTA waits for thread 0
        __syncthreads();
    }

    int phase_counter_;
    int *b0_;
    int *b1_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
    using Type = typename Base::Type;

    enum { SMEM_BYTES = Base::SMEM_BYTES };

    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                              uint32_t lane, void *smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , inter_cta_(params, bidm, bidn)
        , bidn_(bidn)  // CTA id within the group.
        , w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
    {
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        data = Base::reduce(data, op);
        // We switch workspace every iteration.
        T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        // Warp leaders 0 hold the CTA-local results.
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            workspace[bidn_] = data;
        }
        inter_cta_.sync();
        static_assert(CTAS_PER_ROW <= 32);
        T total = Zeros<T>::get();
        if( this->lane_ < CTAS_PER_ROW ) {
            total = workspace[this->lane_];
        }
        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

        return total;
    }

    InterCTASync inter_cta_;

    T *w0_;
    T *w1_;
    int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {

    using Type = T;

    enum { SMEM_BYTES = 0 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                              uint32_t lane, void *smem)
        : warp_n_(warp_n)
        , lane_(lane)
    {
    }

    template<typename Op>
    static inline __device__ T allreduce_(T data, Op &op) {
        #pragma unroll
        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
            data = op(data, warp_shuffle_xor(data, it));
        }
        return data;
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        return allreduce_(data, op);
    }

    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        // only lane 0 holds the result!
        #pragma unroll
        for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
            data = op(data, warp_shuffle_down(data, it));
        }
        return data;
    }

    int warp_n_;
    int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {

    using Base = Reducer<T, 1, WARPS_M, 1>;
    using Type = T;

    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    template<typename Params>
    inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                              uint32_t lane, void *smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        T *smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < WARPS_N; it++ ) {
            out = op(out, smem[it]);
        }
        return out;
    }

    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        T *smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        // only intra-CTA group leader holds the result!
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            smem[this->warp_n_] = data;
        }
        __syncthreads();
        T out = Zeros<T>::get();
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            #pragma unroll
            for( int it = 0; it < WARPS_N; it++ ) {
                out = op(out, smem[it]);
            }
        }
        return out;
    }

    T *smem0_;
    T *smem1_;
    bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) {
    // Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
    int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

    #pragma unroll
    for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
        // Exchange
        T n_b = warp_shuffle_down(n_a, step);
        T m_b = warp_shuffle_down(m_a, step);
        T m2_b = warp_shuffle_down(m2_a, step);

        // Update
        const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
        const T rn_ab = 1.f / n_ab;  // Might have different n per thread, otherwise this would simplify :(
        const T delta = m_a - m_b;
        const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
        const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

        n_a = n_ab;
        m_a = m_ab;
        m2_a = m2_ab;
    }
    // Intra-warp broadcast (only lane 0 has valid stats).
#if defined(__HIP_PLATFORM_HCC__)
    m_a = __shfl(m_a, 0);
    m2_a = __shfl(m2_a, 0);
#else
    m_a = __shfl_sync(uint32_t(-1), m_a, 0);
    m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
#endif
}
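For reference (not part of the deleted header): each shuffle step above merges two partial statistics with the standard pairwise mean/variance combination, where m is a partial mean, m2 the partial sum of squared deviations, and n the partial count. The host-side sketch below spells out that update and checks it on a tiny hand-computed case; the struct and function names are invented for the example.

// Host sketch: the same pairwise merge applied per step by warp_chan_upd_dynamic.
//   n_ab  = n_a + n_b
//   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
//   m2_ab = m2_a + m2_b + delta^2 * n_a * n_b / n_ab,   delta = m_a - m_b
#include <cassert>
#include <cmath>

struct Partial { float n, m, m2; };

inline Partial merge(const Partial &a, const Partial &b) {
    const float n_ab  = a.n + b.n;
    const float delta = a.m - b.m;
    return { n_ab,
             (a.n * a.m + b.n * b.m) / n_ab,
             a.m2 + b.m2 + delta * delta * a.n * b.n / n_ab };
}

int main() {
    // {1, 2} has (n, mean, m2) = (2, 1.5, 0.5); {3, 4, 5} has (3, 4.0, 2.0).
    // Merged, {1..5} should give mean 3 and m2 = 10.
    Partial ab = merge(Partial{2.f, 1.5f, 0.5f}, Partial{3.f, 4.0f, 2.0f});
    assert(std::fabs(ab.m - 3.f) < 1e-6f && std::fabs(ab.m2 - 10.f) < 1e-6f);
    return 0;
}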
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
    using stats_t = typename BlockStats::stats_t;

    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

    template<typename Params>
    inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
        : inter_cta_(params, bidm, bidn)
        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , bidn_(bidn)  // CTA id within the group.
        , w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
        , warp_n_(warp_n)
        , lane_(lane)
    {
    }

    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
        // TODO rn is not really needed here..
        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
        stats_t block_stats = block_stats_.compute(elts, block_rn);

        stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        if( warp_n_ == 0 && lane_ == 0 ) {
            workspace[bidn_] = block_stats;
        }

        // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
        inter_cta_.sync();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
        static_assert(CTAS_PER_ROW <= 32);

        // Every warp does the final reduction locally.
        if( lane_ < CTAS_PER_ROW ) {
            stats_t result = workspace[lane_];
            n = ELTS_PER_ROW_PER_CTA;
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

        return {m, m2};
    }

    InterCTASync inter_cta_;
    BlockStats block_stats_;

    stats_t *w0_;
    stats_t *w1_;
    int bidn_;
    int warp_n_;
    int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {

    using WarpStats = Stats<T, 1, WARPS_M, 1>;
    using stats_t = typename WarpStats::stats_t;

    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

    template<typename Params>
    inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , use0_(true)
    {
        smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
        smem1_ = smem0_ + WARPS_M * WARPS_N;
    }

    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        stats_t *smem = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        // Compute warp local for all WARPS_N
        constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
        stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

        // Each warp leader stores its stats
        const auto warp_n = warp_stats_.reducer_.warp_n_;
        const auto lane = warp_stats_.reducer_.lane_;
        if( lane == 0 ) {
            smem[warp_n] = warp_stats;
        }
        __syncthreads();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume that there are less than 32 warps, such that we can finalize with a single warp
        static_assert(WARPS_N <= 32);
        if( lane < WARPS_N ) {
            stats_t result = smem[lane];
            n = N * THREADS_PER_WARP;
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        warp_chan_upd_dynamic(m, m2, n, WARPS_N);

        return {m, m2};
    }

    WarpStats warp_stats_;
    stats_t *smem0_;
    stats_t *smem1_;
    bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {

    using stats_t = typename TypeToVec2<T>::Type;
    // The simple Warp reducer.
    using Reducer = Reducer<T, 1, WARPS_M, 1>;

    enum { SMEM_BYTES = 0 };

    template<typename Params>
    inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
    {
    }

    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        auto sum = Sum<T>();

        T m = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            m += elts[it];
        }
        m = reducer_.allreduce(m, sum) * rn;

        T m2 = Zeros<T>::get();
        #pragma unroll
        for( int it = 0; it < N; it++ ) {
            T diff = (elts[it] - m);
            m2 += diff * diff;
        }
        m2 = reducer_.allreduce(m2, sum);

        return {m, m2};
    }

    Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu  deleted 100644 → 0  View file @ 2a4864d5
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.cuh"
#include "softmax.cuh"

// symbol to be automatically resolved by PyTorch libs

namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
                                    torch::Tensor const &input,
                                    const half *pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
        k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple =
        at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2
  return {dropout_results, dropout_mask, softmax_results};
}

torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
                       torch::Tensor const &softmax_results,
                       torch::Tensor const &dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(output_grads.data_ptr()),
      static_cast<half *>(output_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // backward pass is completely in-place
  return output_grads;
}

} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
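For orientation (not part of the deleted file): the additive-mask path above dispatches to dispatch_additive_masked_softmax, whose expected semantics is a numerically stable softmax over the key dimension of (logits + mask), where the mask is typically 0 for kept positions and a large negative value for masked ones. A minimal host reference sketch, with invented names, under that assumption:

// Host reference sketch of additive-masked softmax over one row of logits.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> additive_masked_softmax(const std::vector<float> &logits,
                                           const std::vector<float> &mask) {
    std::vector<float> x(logits.size());
    for (size_t i = 0; i < x.size(); ++i) x[i] = logits[i] + mask[i];  // mask: 0 or e.g. -1e4

    const float x_max = *std::max_element(x.begin(), x.end());  // subtract max for stability
    float denom = 0.f;
    for (float &v : x) { v = std::exp(v - x_max); denom += v; }
    for (float &v : x) v /= denom;
    return x;
}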
apex/contrib/csrc/multihead_attn/dropout.cuh  deleted 100644 → 0  View file @ 2a4864d5
#pragma once

#include <ATen/ATen.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>

namespace {
constexpr int UNROLL = 4;
} // namespace

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, uint8_t *mask,
                                          IndexType totalElements, accscalar_t p,
                                          std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs, scalar_t const *add_inputs,
                                        scalar_t *outputs, uint8_t *mask, IndexType totalElements,
                                        accscalar_t p, std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs, scalar_t const *add_inputs,
                                scalar_t *outputs, IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] + add_src[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs, scalar_t *outputs,
                                         uint8_t const *mask, IndexType totalElements,
                                         accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size =
      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = static_cast<scalar_t>(inputs[li]);
        msk[ii] = static_cast<scalar_t>(mask[li]);
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs, uint8_t *mask,
                             IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      (unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm,
      grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
  }
  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs,
                           uint8_t *mask, IndexType totalElements, accscalar_t p) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      (unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm,
      grid.x);
  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
  }
  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, scalar_t *outputs,
                   IndexType totalElements) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      (unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm,
      grid.x);
  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs, uint8_t const *mask,
                            IndexType totalElements, accscalar_t scale) {
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      (unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm,
      grid.x);
  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
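A note on the counter_offset arithmetic (not part of the deleted header): each thread draws one curand_uniform4 (four randoms) per UNROLL-wide pass of its grid-stride loop, so the philox counter must be advanced by the number of passes times UNROLL. A host-side sketch of that computation, with an illustrative worked example in the trailing comment; the function name and the example numbers are assumptions:

// Sketch: philox offset used by the launchers above (ceil-div over the grid-stride passes).
#include <cstdint>

int64_t philox_counter_offset(int64_t totalElements, int block_size, int grid_x, int unroll /* = 4 */) {
    const int64_t elems_per_pass = int64_t(block_size) * grid_x * unroll;
    const int64_t passes = (totalElements - 1) / elems_per_pass + 1;  // ceil division
    return passes * unroll;
}

// Example (assumed numbers): 10'000'000 elements, 256 threads/block, grid clamped to 432 blocks
//   -> elems_per_pass = 442'368, passes = 23, counter_offset = 92.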
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  deleted 100644 → 0  View file @ 2a4864d5
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
                                    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs_q.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_q_results =
      torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);
  torch::Tensor input_lin_kv_results =
      torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_q_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(input_lin_kv_results.data_ptr());
  void *v_lin_results_ptr =
      static_cast<void *>(static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  rocblas_int flags = 0;
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Input Linear Q Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Input Linear KV Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(inputs_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      k_lin_results_ptr, rocblas_datatype_f16_r, output_lin_kv_dim,
      k_lin_results_ptr, rocblas_datatype_f16_r, output_lin_kv_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
                        static_cast<const half *>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half *>(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta,
                        static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches, flags);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()),
        dropout_elems, (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
                        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        (is_training) ? static_cast<const half *>(dropout_results.data_ptr())
                                      : static_cast<const half *>(softmax_results.data_ptr()),
                        k_seq_len, k_seq_len * q_seq_len, beta,
                        static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        attn_batches, flags);

  // Output Linear
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_q_results, input_lin_kv_results, softmax_results, dropout_results,
          dropout_mask, matmul2_results, outputs};
}
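For orientation (not part of the deleted file): the two gemm_switch_fp32accum calls above implement scaled dot-product attention per head, with K and V stored interleaved in the packed KV activation. The strides work out as in the sketch below for one hypothetical configuration; the concrete numbers are assumptions chosen only to make the arithmetic visible.

// Shape sketch for an assumed configuration: heads = 16, sequences = 32, embed_dim = 1024.
constexpr int heads = 16, sequences = 32, embed_dim = 1024;
constexpr int head_dim        = embed_dim / heads;            // 64
constexpr int attn_batches    = heads * sequences;            // 512 independent per-head GEMMs
constexpr int lead_dim_q      = attn_batches * head_dim;      // leading dim of the packed Q view
constexpr int lead_dim_kv     = attn_batches * 2 * head_dim;  // K and V interleaved per head
constexpr int batch_stride_q  = head_dim;                     // next head starts head_dim columns over
constexpr int batch_stride_kv = 2 * head_dim;

// MatMul1 (per head b): scores[b] = (1/sqrt(head_dim)) * K[b]^T * Q[b]   -> k_seq_len x q_seq_len
// MatMul2 (per head b): context[b] = V[b] * dropout(softmax(scores[b]))  -> head_dim  x q_seq_len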
std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const &output_grads,
                                    torch::Tensor const &matmul2_results,
                                    torch::Tensor const &dropout_results,
                                    torch::Tensor const &softmax_results,
                                    torch::Tensor const &input_lin_q_results,
                                    torch::Tensor const &input_lin_kv_results,
                                    torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv,
                                    torch::Tensor const &input_weights_q,
                                    torch::Tensor const &input_weights_kv,
                                    torch::Tensor const &output_weights,
                                    torch::Tensor const &dropout_mask, float dropout_prob) {
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_q_grads = torch::empty_like(inputs_q);
  torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);
  torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);
  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr = static_cast<half *>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  rocblas_int flags = 0;
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
                        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half *>(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim, beta,
                        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches, flags);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
                        static_cast<const half *>(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim,
                        static_cast<const half *>(dropout_results.data_ptr()),
                        k_seq_len, k_seq_len * q_seq_len, beta,
                        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        attn_batches, flags);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(matmul2_grads.data_ptr()),
      static_cast<at::Half *>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      dropout_elems, (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
                        k_lin_results_ptr, lead_dim_kv, batch_stride_kv,
                        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
                        q_lin_grads_ptr, lead_dim_q, batch_stride_q,
                        q_lin_grads_ptr, lead_dim_q, batch_stride_q,
                        attn_batches, flags);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
                        q_lin_results_ptr, lead_dim_q, batch_stride_q,
                        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
                        k_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        k_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        attn_batches, flags);

  // Input Linear Q Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_q_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Input Linear Q Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_q_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Input Linear KV Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(k_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_kv_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Input Linear KV Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(inputs_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(k_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_kv_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
  // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads,
          output_weight_grads};
}

} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
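As a side note (not part of the deleted file): the apex_masked_scale_cuda step in the backward pass above is just the elementwise inverse of the forward dropout — the saved byte mask re-zeroes dropped positions and survivors are rescaled by 1 / (1 - p). A minimal host sketch of that step, with invented names:

// Host sketch of the dropout backward applied before the softmax backward.
#include <cstdint>
#include <vector>

std::vector<float> dropout_backward(const std::vector<float> &grad_out,
                                    const std::vector<uint8_t> &mask, float p) {
    const float scale = 1.0f / (1.0f - p);
    std::vector<float> grad_in(grad_out.size());
    for (size_t i = 0; i < grad_in.size(); ++i)
        grad_in[i] = grad_out[i] * scale * static_cast<float>(mask[i]);  // mask is 0 or 1
    return grad_in;
}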
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  deleted 100644 → 0  View file @ 2a4864d5
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace
multihead_attn
{
namespace
encdec_norm_add
{
namespace
rocblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
)
{
  const int embed_dim = inputs_q.size(2);
  const int sequences = inputs_q.size(1);
  const int q_seq_len = inputs_q.size(0);
  const int k_seq_len = inputs_kv.size(0);
  const int batches_q = sequences * q_seq_len;
  const int batches_kv = sequences * k_seq_len;
  const int total_tokens_q = batches_q * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_q_dim = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim_q = attn_batches * head_dim;
  const int lead_dim_kv = attn_batches * 2 * head_dim;
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
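For reference (not part of the original source): `scale` is the standard scaled dot-product attention factor. With d_k = head_dim, the MatMul1 and softmax stages further down compute

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( Q K^{\top} / \sqrt{d_k} \right) V,

so folding `scale` into the first batched GEMM is exactly the 1/\sqrt{d_k} normalization.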
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
inputs_q
.
options
().
requires_grad
(
false
);
auto
lyr_nrm_options
=
act_options
.
dtype
(
torch
::
kFloat32
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
lyr_nrm_mean
=
torch
::
empty
({
batches_q
},
lyr_nrm_options
);
torch
::
Tensor
lyr_nrm_invvar
=
torch
::
empty
({
batches_q
},
lyr_nrm_options
);
torch
::
Tensor
lyr_nrm_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
input_lin_q_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_q_dim
},
act_options
);
torch
::
Tensor
input_lin_kv_results
=
torch
::
empty
({
k_seq_len
,
sequences
,
output_lin_kv_dim
},
act_options
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
output_lin_results
=
torch
::
empty_like
(
inputs_q
,
act_options
);
torch
::
Tensor
dropout_add_mask
=
torch
::
empty_like
(
inputs_q
,
mask_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs_q
,
act_options
);
// Input Linear Results Pointers to Q, K, and V of interleaved activations
void
*
q_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_q_results
.
data_ptr
());
void
*
k_lin_results_ptr
=
static_cast
<
void
*>
(
input_lin_kv_results
.
data_ptr
());
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
char
a_layout_t
{
't'
};
char
a_layout_n
{
'n'
};
char
b_layout_n
{
'n'
};
rocblas_int
flags
=
0
;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm
<
at
::
Half
,
float
>
(
static_cast
<
at
::
Half
*>
(
lyr_nrm_results
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_mean
.
data_ptr
()),
static_cast
<
float
*>
(
lyr_nrm_invvar
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
int
>
(
batches_q
),
// n1
static_cast
<
int
>
(
embed_dim
),
// n2
1.0e-5
,
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_gamma_weights
.
data_ptr
()),
static_cast
<
const
at
::
Half
*>
(
lyr_nrm_beta_weights
.
data_ptr
()));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
output_lin_q_dim
,
batches_q
,
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_q
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast
<
const
void
*>
(
lyr_nrm_results
.
data_ptr
()),
rocblas_datatype_f16_r
/*b_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
q_lin_results_ptr
,
rocblas_datatype_f16_r
/*c_type*/
,
output_lin_q_dim
,
q_lin_results_ptr
,
rocblas_datatype_f16_r
/*d_type*/
,
output_lin_q_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
output_lin_kv_dim
,
batches_kv
,
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_kv
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
inputs_kv
.
data_ptr
()),
rocblas_datatype_f16_r
/*b_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
k_lin_results_ptr
,
rocblas_datatype_f16_r
/*c_type*/
,
output_lin_kv_dim
,
k_lin_results_ptr
,
rocblas_datatype_f16_r
/*d_type*/
,
output_lin_kv_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum
(
a_layout_t
,
b_layout_n
,
k_seq_len
,
q_seq_len
,
head_dim
,
scale
,
static_cast
<
const
half
*>
(
k_lin_results_ptr
),
lead_dim_kv
,
batch_stride_kv
,
static_cast
<
const
half
*>
(
q_lin_results_ptr
),
lead_dim_q
,
batch_stride_q
,
beta
,
static_cast
<
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
*
q_seq_len
,
static_cast
<
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
*
q_seq_len
,
attn_batches
,
flags
);
// Padded Softmax
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
}
else
{
if
(
use_time_mask
)
{
softmax_success
=
dispatch_time_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
q_seq_len
);
}
else
{
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
}
}
assert
(
softmax_success
);
if
(
is_training
)
{
apex_fused_dropout_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
f
-
dropout_prob
));
}
// Matmul2
gemm_switch_fp32accum
(
a_layout_n
,
b_layout_n
,
head_dim
,
q_seq_len
,
k_seq_len
,
alpha
,
static_cast
<
const
half
*>
(
v_lin_results_ptr
),
lead_dim_kv
,
batch_stride_kv
,
(
is_training
)
?
static_cast
<
const
half
*>
(
dropout_results
.
data_ptr
())
:
static_cast
<
const
half
*>
(
softmax_results
.
data_ptr
()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
static_cast
<
half
*>
(
matmul2_results
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
static_cast
<
half
*>
(
matmul2_results
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
attn_batches
,
flags
);
// Output Linear
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
embed_dim
,
batches_q
,
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
output_weights
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
matmul2_results
.
data_ptr
()),
rocblas_datatype_f16_r
/*b_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
output_lin_results
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
output_lin_results
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// End-of-block Dropout-Add
if
(
is_training
)
{
apex_dropout_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens_q
,
(
1.0
f
-
dropout_prob
));
}
else
{
apex_add_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
Half
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
outputs
.
data_ptr
()),
total_tokens_q
);
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
lyr_nrm_results
,
lyr_nrm_mean
,
lyr_nrm_invvar
,
input_lin_q_results
,
input_lin_kv_results
,
softmax_results
,
dropout_results
,
dropout_mask
,
matmul2_results
,
dropout_add_mask
,
outputs
};
}
std
::
vector
<
torch
::
Tensor
>
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
input_lin_q_results
,
torch
::
Tensor
const
&
input_lin_kv_results
,
torch
::
Tensor
const
&
lyr_nrm_results
,
torch
::
Tensor
const
&
lyr_nrm_mean
,
torch
::
Tensor
const
&
lyr_nrm_invvar
,
torch
::
Tensor
const
&
inputs_q
,
torch
::
Tensor
const
&
inputs_kv
,
torch
::
Tensor
const
&
lyr_nrm_gamma_weights
,
torch
::
Tensor
const
&
lyr_nrm_beta_weights
,
torch
::
Tensor
const
&
input_weights_q
,
torch
::
Tensor
const
&
input_weights_kv
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
dropout_add_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs_q
.
size
(
2
);
const
int
sequences
=
inputs_q
.
size
(
1
);
const
int
q_seq_len
=
inputs_q
.
size
(
0
);
const
int
k_seq_len
=
inputs_kv
.
size
(
0
);
const
int
batches_q
=
sequences
*
q_seq_len
;
const
int
batches_kv
=
sequences
*
k_seq_len
;
const
int
total_tokens_q
=
batches_q
*
embed_dim
;
const
int
head_dim
=
embed_dim
/
heads
;
const
int
output_lin_q_dim
=
embed_dim
;
const
int
output_lin_kv_dim
=
2
*
embed_dim
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
lead_dim_q
=
attn_batches
*
head_dim
;
const
int
lead_dim_kv
=
attn_batches
*
2
*
head_dim
;
const
int
batch_stride_q
=
head_dim
;
const
int
batch_stride_kv
=
2
*
head_dim
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
const
float
alpha
=
1.0
;
const
float
beta
=
0.0
;
const
float
scale
=
1.0
/
sqrt
(
static_cast
<
float
>
(
head_dim
));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// Output Tensor Allocations
torch
::
Tensor
input_q_grads
=
torch
::
empty_like
(
inputs_q
);
torch
::
Tensor
input_kv_grads
=
torch
::
empty_like
(
inputs_kv
);
torch
::
Tensor
lyr_nrm_gamma_grads
=
torch
::
empty_like
(
lyr_nrm_gamma_weights
);
torch
::
Tensor
lyr_nrm_beta_grads
=
torch
::
empty_like
(
lyr_nrm_beta_weights
);
torch
::
Tensor
input_weight_q_grads
=
torch
::
empty_like
(
input_weights_q
);
torch
::
Tensor
input_weight_kv_grads
=
torch
::
empty_like
(
input_weights_kv
);
torch
::
Tensor
output_weight_grads
=
torch
::
empty_like
(
output_weights
);
// Intermediate Tensor Allocations
at
::
Tensor
dropout_add_grads
=
torch
::
empty_like
(
output_grads
);
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
at
::
Tensor
input_lin_kv_output_grads
=
torch
::
empty_like
(
input_lin_kv_results
);
at
::
Tensor
input_lin_q_grads
=
torch
::
empty_like
(
inputs_q
);
auto
q_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_q_results
.
data_ptr
());
auto
k_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
());
auto
v_lin_results_ptr
=
static_cast
<
half
*>
(
input_lin_kv_results
.
data_ptr
())
+
head_dim
;
auto
q_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_q_output_grads
.
data_ptr
());
auto
k_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
());
auto
v_lin_grads_ptr
=
static_cast
<
half
*>
(
input_lin_kv_output_grads
.
data_ptr
())
+
head_dim
;
char
a_layout_n
{
'n'
};
char
a_layout_t
{
't'
};
char
b_layout_n
{
'n'
};
char
b_layout_t
{
't'
};
rocblas_int
flags
=
0
;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
flags
=
at
::
BACKWARD_PASS_GUARD_CLASS
::
is_backward_pass
()
?
rocblas_gemm_flags_fp16_alt_impl
:
0
;
#endif
#endif
#endif
// Dropout Add Backward
apex_masked_scale_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
dropout_add_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens_q
,
(
1.0
/
(
1.0
-
dropout_prob
)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
batches_q
,
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
output_weights
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
dropout_add_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*b_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
output_lin_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
output_lin_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
embed_dim
,
embed_dim
,
batches_q
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
matmul2_results
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
dropout_add_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*b_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
output_weight_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
output_weight_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// MatMul2 Dgrad1
gemm_switch_fp32accum
(
a_layout_t
,
b_layout_n
,
k_seq_len
,
q_seq_len
,
head_dim
,
alpha
,
static_cast
<
const
half
*>
(
v_lin_results_ptr
),
lead_dim_kv
,
batch_stride_kv
,
static_cast
<
const
half
*>
(
output_lin_grads
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
beta
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
attn_batches
,
flags
);
// Matmul2 Dgrad2
gemm_switch_fp32accum
(
a_layout_n
,
b_layout_t
,
head_dim
,
k_seq_len
,
q_seq_len
,
alpha
,
static_cast
<
const
half
*>
(
output_lin_grads
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
static_cast
<
const
half
*>
(
dropout_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
v_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
v_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
attn_batches
,
flags
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
at
::
Half
,
float
,
uint32_t
>
(
static_cast
<
at
::
Half
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
at
::
Half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
// Softmax Grad
bool
softmax_success
=
false
;
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
assert
(
softmax_success
);
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
a_layout_n
,
b_layout_n
,
head_dim
,
q_seq_len
,
k_seq_len
,
scale
,
k_lin_results_ptr
,
lead_dim_kv
,
batch_stride_kv
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
q_lin_grads_ptr
,
lead_dim_q
,
batch_stride_q
,
q_lin_grads_ptr
,
lead_dim_q
,
batch_stride_q
,
attn_batches
,
flags
);
// Matmul1 Dgrad2
gemm_switch_fp32accum
(
a_layout_n
,
b_layout_t
,
head_dim
,
k_seq_len
,
q_seq_len
,
scale
,
q_lin_results_ptr
,
lead_dim_q
,
batch_stride_q
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
k_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
k_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
attn_batches
,
flags
);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
batches_q
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_q
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
rocblas_datatype_f16_r
/*b_type*/
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
beta
),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast
<
void
*>
(
input_lin_q_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
input_lin_q_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
embed_dim
,
output_lin_q_dim
,
batches_q
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
inputs_q
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
rocblas_datatype_f16_r
/*b_type*/
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_weight_q_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
input_weight_q_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
batches_kv
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_kv
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
k_lin_grads_ptr
),
rocblas_datatype_f16_r
/*b_type*/
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_kv_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
input_kv_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
embed_dim
,
output_lin_kv_dim
,
batches_kv
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
inputs_kv
.
data_ptr
()),
rocblas_datatype_f16_r
/*a_type*/
,
embed_dim
,
static_cast
<
const
void
*>
(
k_lin_grads_ptr
),
rocblas_datatype_f16_r
/*b_type*/
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_weight_kv_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*c_type*/
,
embed_dim
,
static_cast
<
void
*>
(
input_weight_kv_grads
.
data_ptr
()),
rocblas_datatype_f16_r
/*d_type*/
,
embed_dim
,
rocblas_datatype_f32_r
/*compute_type*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
flags
));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient
<
half
,
float
>
(
static_cast
<
const
half
*>
(
input_lin_q_grads
.
data_ptr
()),
static_cast
<
half
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
const
float
*>
(
lyr_nrm_mean
.
data_ptr
()),
static_cast
<
const
float
*>
(
lyr_nrm_invvar
.
data_ptr
()),
inputs_q
,
static_cast
<
int
>
(
batches_q
),
// n1
static_cast
<
int
>
(
embed_dim
),
// n2
static_cast
<
const
half
*>
(
lyr_nrm_gamma_weights
.
data_ptr
()),
static_cast
<
const
half
*>
(
lyr_nrm_beta_weights
.
data_ptr
()),
1.0e-5
,
static_cast
<
half
*>
(
input_q_grads
.
data_ptr
()),
static_cast
<
half
*>
(
lyr_nrm_gamma_grads
.
data_ptr
()),
static_cast
<
half
*>
(
lyr_nrm_beta_grads
.
data_ptr
())
);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
input_q_grads
,
input_kv_grads
,
lyr_nrm_gamma_grads
,
lyr_nrm_beta_grads
,
input_weight_q_grads
,
input_weight_kv_grads
,
output_weight_grads
};
}
}
// end namespace rocblas_gemmex
}
// end namespace encdec_norm_add
}
// end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/layer_norm.cuh  deleted  100644 → 0
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/DeviceUtils.cuh>
namespace {
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}
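This is the single-pass Welford recurrence for a running mean and sum of squared deviations; as a reading aid (not part of the source), for the k-th sample x_k it performs

    \mu_k = \mu_{k-1} + (x_k - \mu_{k-1}) / k
    M_{2,k} = M_{2,k-1} + (x_k - \mu_{k-1})(x_k - \mu_k)

so the variance over n elements is M_{2,n} / n, which is why the callers below divide `sigma2` by `n2` once the reduction has finished.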
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
                                U &mu, U &sigma2, U &count) {
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}
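This is the pairwise (Chan et al.) merge of two partial Welford accumulators A and B, used by the warp-shuffle and shared-memory reductions further down; as a reading aid (not part of the source), in the usual notation it computes

    n = n_A + n_B,    \delta = \mu_B - \mu_A,
    \mu = (n_A \mu_A + n_B \mu_B) / n,
    M_2 = M_{2,A} + M_{2,B} + \delta^2 \, n_A n_B / n.

The code divides n_A and n_B by nX = n before combining and multiplies the last term by nX, which is algebraically the same update.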
template
<
typename
T
,
typename
U
>
__device__
void
cuWelfordMuSigma2
(
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
U
&
mu
,
U
&
sigma2
,
U
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U
count
=
U
(
0
);
mu
=
U
(
0
);
sigma2
=
U
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
+
k
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
U
muB
=
WARP_SHFL
(
mu
,
srcLaneB
,
32
);
U
countB
=
WARP_SHFL
(
count
,
srcLaneB
,
32
);
U
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
,
32
);
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
U
*
ubuf
=
(
U
*
)
buf
;
U
*
ibuf
=
(
U
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
U
muB
=
ubuf
[
2
*
threadIdx
.
y
];
U
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
U
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
U
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
sigma2
=
WARP_SHFL
(
sigma2
/
U
(
n2
),
0
,
32
);
}
}
}
template
<
>
__device__
void
cuWelfordMuSigma2
(
const
at
::
Half
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
float
&
mu
,
float
&
sigma2
,
float
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float
count
=
0.0
f
;
mu
=
float
(
0
);
sigma2
=
float
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
at
::
Half
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
8
*
thrx
;
if
((((
size_t
)
lvals
)
&
3
)
!=
0
)
{
// 16 bit alignment
// first thread consumes first point
if
(
thrx
==
0
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
0
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
++
l
;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
l
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
float
muB
=
WARP_SHFL
(
mu
,
srcLaneB
,
32
);
float
countB
=
WARP_SHFL
(
count
,
srcLaneB
,
32
);
float
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
,
32
);
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
float
*
ubuf
=
(
float
*
)
buf
;
float
*
ibuf
=
(
float
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
float
muB
=
ubuf
[
2
*
threadIdx
.
y
];
float
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
float
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
float
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
,
32
);
sigma2
=
WARP_SHFL
(
sigma2
/
float
(
n2
),
0
,
32
);
}
}
}
template
<
typename
U
>
U
rsqrt
(
U
v
)
{
return
U
(
1
)
/
sqrt
(
v
);
}
//template<> float rsqrt(float v) {
// return rsqrtf(v);
//}
#if defined __HIP_PLATFORM_HCC__
__device__
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
#else
template
<
>
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
#endif
template
<
>
double
rsqrt
(
double
v
)
{
return
rsqrt
(
v
);
}
// template <typename U> __device__ U rsqrt(U v) { return U(1) / sqrt(v); }
// template <> __device__ float rsqrt(float v) { return rsqrtf(v); }
// template <> __device__ double rsqrt(double v) { return rsqrt(v); }
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template
<
typename
T
>
struct
SharedMemory
;
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
getPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
template
<
>
struct
SharedMemory
<
double
>
{
__device__
double
*
getPointer
()
{
extern
__shared__
double
s_double
[];
return
s_double
;
}
};
template
<
typename
T
,
typename
U
>
__global__
void
cuApplyLayerNorm
(
T
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
T
*
__restrict__
gamma
,
const
T
*
__restrict__
beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
const
T
*
lvals
=
vals
+
i1
*
n2
;
T
*
ovals
=
output_vals
+
i1
*
n2
;
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
}
}
else
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
));
}
}
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
mean
[
i1
]
=
mu
;
invvar
[
i1
]
=
c_invvar
;
}
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
else
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputePartGradGammaBeta
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
numsegs_n1
=
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U
*
warp_buf1
=
(
U
*
)
buf
;
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
}
__syncthreads
();
// inter-warp reductions
// sum within each warp
U
acc1
=
U
(
0
);
U
acc2
=
U
(
0
);
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
acc1
+=
warp_buf1
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
}
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
__syncthreads
();
// sum all warps
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
__syncthreads
();
}
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
1
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
n1
,
const
int
n2
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
U
sum_gamma
=
U
(
0
);
U
sum_beta
=
U
(
0
);
const
U
*
part_grad_gamma_ptr
=
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
const
U
*
part_grad_beta_ptr
=
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
}
// inter-warp reductions
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
// top half write to shared memory
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
}
__syncthreads
();
// bottom half sums
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_gamma
+=
buf
[
read_idx
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
}
__syncthreads
();
}
// write out fully summed gradients
if
(
threadIdx
.
y
==
0
)
{
grad_gamma
[
i2
]
=
sum_gamma
;
grad_beta
[
i2
]
=
sum_beta
;
}
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputeGradInput
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
dout_resid
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
T
*
gamma
,
T
*
grad_input
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
const
T
*
k_dout_resid
=
dout_resid
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
else
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
// intra-warp reductions
for
(
int
mask
=
blockDim
.
x
/
2
;
mask
>
0
;
mask
/=
2
)
{
sum_loss1
+=
WARP_SHFL_XOR
(
sum_loss1
,
mask
,
32
);
sum_loss2
+=
WARP_SHFL_XOR
(
sum_loss2
,
mask
,
32
);
}
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
}
__syncthreads
();
}
if
(
threadIdx
.
y
==
0
)
{
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
}
__syncthreads
();
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
if
(
gamma
!=
NULL
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
else
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
}
}
template
<
typename
T
,
typename
U
>
void
HostApplyLayerNorm
(
T
*
output
,
U
*
mean
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
int
n2
,
double
epsilon
,
const
T
*
gamma
,
const
T
*
beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
}
template
<
typename
T
,
typename
U
>
void
HostLayerNormGradient
(
const
T
*
dout
,
const
T
*
dout_resid
,
const
U
*
mean
,
const
U
*
invvar
,
const
at
::
Tensor
&
input
,
int
n1
,
int
n2
,
const
T
*
gamma
,
const
T
*
beta
,
double
epsilon
,
T
*
grad_input
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
// compute grad_gamma(j) and grad_beta(j)
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
(
{
part_size
,
n2
},
input
.
options
().
dtype
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar_type
()));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()));
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()),
part_size
,
n1
,
n2
,
grad_gamma
,
grad_beta
);
}
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
dout
,
dout_resid
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
}
}
// namespace
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu  deleted  100644 → 0
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
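As a reading aid (not part of the source): the forward path in this file applies a padding-masked softmax over the key dimension followed by inverted dropout, roughly

    p_j = \exp(s_j) / \sum_{j' \notin \text{mask}} \exp(s_{j'}),    y_j = z_j \, p_j / (1 - p_{\text{drop}}),  z_j \sim \mathrm{Bernoulli}(1 - p_{\text{drop}}),

with the exact masking convention determined by the dispatch_masked_softmax kernels in softmax.cuh.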
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
input
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
)
{
const
int
attn_batches
=
input
.
size
(
0
);
const
int
sequences
=
attn_batches
/
heads
;
const
int
q_seq_len
=
input
.
size
(
1
);
const
int
k_seq_len
=
q_seq_len
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
input
.
options
().
requires_grad
(
false
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
input_ptr
=
static_cast
<
void
*>
(
input
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
// Padded Softmax
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
}
else
{
softmax_success
=
dispatch_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
pad_mask
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
);
}
if
(
is_training
)
{
// use at:: function so that C++ version generates the same random mask as
// python version
auto
dropout_tuple
=
at
::
_fused_dropout
(
softmax_results
,
1.0
f
-
dropout_prob
);
dropout_results
=
std
::
get
<
0
>
(
dropout_tuple
);
dropout_mask
=
std
::
get
<
1
>
(
dropout_tuple
);
}
// Matmul2
return
{
dropout_results
,
dropout_mask
,
softmax_results
};
}
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
const
uint8_t
*
padding_mask
,
float
dropout_prob
)
{
const
int
attn_batches
=
output_grads
.
size
(
0
);
const
int
q_seq_len
=
output_grads
.
size
(
1
);
const
int
k_seq_len
=
q_seq_len
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if
(
padding_mask
==
nullptr
)
{
dispatch_masked_scale_softmax_backward_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
stream
);
}
else
{
dispatch_masked_scale_softmax_backward_masked_out_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
padding_mask
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
heads
,
stream
);
}
// backward pass is completely in-place
return
output_grads
;
}
}
// namespace mask_softmax_dropout
}
// namespace fused_softmax
}
// namespace multihead_attn
apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp  deleted  100644 → 0
#include <vector>
#include <cuda_fp16.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
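For context, extension frontends in this style are normally made visible to Python through a PYBIND11_MODULE block at the end of the translation unit (the tail of this file is not part of this excerpt). A minimal sketch under that assumption; the exported Python names below are illustrative, not necessarily the ones apex registers:

#include <torch/extension.h>
// Hypothetical binding sketch: exposes the checked fwd/bwd wrappers declared in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mask_softmax_dropout_forward",
        &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Masked softmax + dropout forward (CUDA)");
  m.def("mask_softmax_dropout_backward",
        &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Masked softmax + dropout backward (CUDA)");
}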
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
input
,
const
half
*
pad_mask
,
float
dropout_prob
);
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
);
std
::
vector
<
torch
::
Tensor
>
fwd
(
bool
use_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
pad_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only BYTE is supported"
);
}
return
fwd_cuda
(
is_training
,
heads
,
input
,
use_mask
?
static_cast
<
const
half
*>
(
pad_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
torch
::
Tensor
bwd
(
bool
use_mask
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return
bwd_cuda
(
heads
,
output_grads
,
softmax_results
,
dropout_mask
,
dropout_prob
);
}
}
// namespace additive_mask_softmax_dropout
namespace
mask_softmax_dropout
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
input
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
);
torch
::
Tensor
bwd_cuda
(
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
const
uint8_t
*
padding_mask
,
float
dropout_prob
);
std
::
vector
<
torch
::
Tensor
>
fwd
(
bool
use_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
pad_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
type
().
scalarType
()
==
at
::
ScalarType
::
Byte
,
"Only BYTE is supported"
);
}
return
fwd_cuda
(
is_training
,
heads
,
input
,
use_mask
?
static_cast
<
const
uint8_t
*>
(
pad_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
torch
::
Tensor
bwd
(
bool
use_mask
,
int
heads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
dropout_mask
,
torch
::
Tensor
const
&
padding_mask
,
float
dropout_prob
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_mask
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
// "Only BYTE is supported");
return
bwd_cuda
(
heads
,
output_grads
,
softmax_results
,
dropout_mask
,
use_mask
?
static_cast
<
const
uint8_t
*>
(
padding_mask
.
data_ptr
())
:
nullptr
,
dropout_prob
);
}
}
// end namespace mask_softmax_dropout
}
// end namespace fused_softmax
namespace
encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results,
    torch::Tensor const& input_lin_kv_results, torch::Tensor const& inputs_q,
    torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q,
    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results,
    torch::Tensor const& input_lin_kv_results, torch::Tensor const& inputs_q,
    torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q,
    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_q_results, input_lin_kv_results, inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
} // end namespace encdec
namespace encdec_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results,
    torch::Tensor const& input_lin_kv_results, torch::Tensor const& lyr_nrm_results,
    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,
    torch::Tensor const& dropout_add_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results,
    torch::Tensor const& input_lin_kv_results, torch::Tensor const& lyr_nrm_results,
    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,
    torch::Tensor const& dropout_add_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_q_results, input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean,
                  lyr_nrm_invvar, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights, dropout_mask,
                  dropout_add_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_results, inputs, input_weights, output_weights, dropout_mask,
                  dropout_prob);
}

} // end namespace rocblas_gemmex
} // end namespace self
namespace self_bias {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& input_biases,
    torch::Tensor const& output_biases, const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const& dropout_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& input_biases,
    torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights,
                  input_biases, output_biases,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_results, inputs, input_weights, output_weights, dropout_mask,
                  dropout_prob);
}

} // end namespace rocblas_gemmex
} // namespace self_bias
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& input_biases,
    torch::Tensor const& output_biases, const half* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    // torch::Tensor const& softmax_results,
    torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask,
    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    // torch::Tensor const& input_biases,
    // torch::Tensor const& output_biases,
    torch::Tensor const& dropout_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& input_biases,
    torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(use_mask, "no mask is not supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only Half is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights,
                  input_biases, output_biases,
                  use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask,
    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask,
                  input_lin_results, inputs, input_weights, output_weights, dropout_mask,
                  dropout_prob);
}

} // end namespace rocblas_gemmex
} // namespace self_bias_additive_mask
namespace self_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,
    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,
    torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob);

std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,
    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,
    torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }
  return fwd_cuda(use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights,
                  lyr_nrm_beta_weights, input_weights, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads,
    torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results,
    torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results,
    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,
    torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("additive_mask_softmax_dropout_forward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("additive_mask_softmax_dropout_backward",
        &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
  m.def("mask_softmax_dropout_forward",
        &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,
        "Self Multihead Attention masked softmax dropout -- Forward.");
  m.def("mask_softmax_dropout_backward",
        &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,
        "Self Multihead Attention masked softmax dropout -- Backward.");
  m.def("encdec_multihead_attn_forward",
        &multihead_attn::encdec::rocblas_gemmex::fwd,
        "Encdec Multihead Attention Forward.");
  m.def("encdec_multihead_attn_backward",
        &multihead_attn::encdec::rocblas_gemmex::bwd,
        "Encdec Multihead Attention Backward.");
  m.def("encdec_multihead_attn_norm_add_forward",
        &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd,
        "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("encdec_multihead_attn_norm_add_backward",
        &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd,
        "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
  m.def("self_attn_forward",
        &multihead_attn::self::rocblas_gemmex::fwd,
        "Self Multihead Attention Forward.");
  m.def("self_attn_backward",
        &multihead_attn::self::rocblas_gemmex::bwd,
        "Self Multihead Attention Backward.");
  m.def("self_attn_bias_forward",
        &multihead_attn::self_bias::rocblas_gemmex::fwd,
        "Self Multihead Attention with Bias -- Forward.");
  m.def("self_attn_bias_backward",
        &multihead_attn::self_bias::rocblas_gemmex::bwd,
        "Self Multihead Attention with Bias -- Backward.");
  m.def("self_attn_bias_additive_mask_forward",
        &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd,
        "Self Multihead Attention with Bias -- Forward.");
  m.def("self_attn_bias_additive_mask_backward",
        &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd,
        "Self Multihead Attention with Bias -- Backward.");
  m.def("self_attn_norm_add_forward",
        &multihead_attn::self_norm_add::rocblas_gemmex::fwd,
        "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("self_attn_norm_add_backward",
        &multihead_attn::self_norm_add::rocblas_gemmex::bwd,
        "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
#undef CHECK_CUDA
#undef CHECK_CONTIGUOUS
#undef CHECK_INPUT
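Taken together, the wrappers in this file encode a simple calling contract: activations are 3-D fp16 tensors laid out [seq_len, batch, embed_dim], the projection weights are 2-D fp16 tensors (the self-attention variants use a fused QKV weight), and the optional padding mask is a 2-D byte tensor. A minimal host-side sketch of tensors that would satisfy those assertions; the shapes (seq_len=64, batch=8, embed_dim=1024) are hypothetical and chosen only for illustration:

// Illustrative only -- not part of the deleted sources.
#include <torch/torch.h>

int main() {
  auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto byte_opts = torch::dtype(torch::kUInt8).device(torch::kCUDA);
  // 3-D activations: [seq_len, batch, embed_dim], fp16, as the wrappers assert.
  torch::Tensor inputs = torch::randn({64, 8, 1024}, half_opts);
  // 2-D fp16 projection weights; the fused QKV projection is [3 * embed_dim, embed_dim].
  torch::Tensor input_weights = torch::randn({3 * 1024, 1024}, half_opts);
  torch::Tensor output_weights = torch::randn({1024, 1024}, half_opts);
  // 2-D byte padding mask (non-additive variants); shape here assumed [batch, seq_len].
  torch::Tensor pad_mask = torch::zeros({8, 64}, byte_opts);
  return 0;
}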
apex/contrib/csrc/multihead_attn/philox.cuh
deleted 100644 → 0 @ 2a4864d5
#pragma once
// Philox CUDA.

namespace {

class Philox {
public:
  __device__ inline Philox(unsigned long long seed, unsigned long long subsequence,
                           unsigned long long offset) {
    key.x = (unsigned int)seed;
    key.y = (unsigned int)(seed >> 32);
    counter = make_uint4(0, 0, 0, 0);
    counter.z = (unsigned int)(subsequence);
    counter.w = (unsigned int)(subsequence >> 32);
    STATE = 0;
    incr_n(offset / 4);
  }

  __device__ inline uint4 operator()() {
    if (STATE == 0) {
      uint4 counter_ = counter;
      uint2 key_ = key;
      // 7-round philox
      for (int i = 0; i < 6; i++) {
        counter_ = single_round(counter_, key_);
        key_.x += (kPhilox10A);
        key_.y += (kPhilox10B);
      }
      output = single_round(counter_, key_);
      incr();
    }
    // return a float4 directly
    // unsigned long ret;
    // switch(STATE) {
    //   case 0: ret = output.x; break;
    //   case 1: ret = output.y; break;
    //   case 2: ret = output.z; break;
    //   case 3: ret = output.w; break;
    // }
    // STATE = (STATE + 1) % 4;
    return output;
  }

private:
  uint4 counter;
  uint4 output;
  uint2 key;
  unsigned int STATE;

  __device__ inline void incr_n(unsigned long long n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);
    counter.x += nlo;
    if (counter.x < nlo)
      nhi++;
    counter.y += nhi;
    if (nhi <= counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }

  __device__ inline void incr() {
    if (++counter.x)
      return;
    if (++counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }

  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int *result_high) {
    *result_high = __umulhi(a, b);
    return a * b;
  }

  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
    unsigned int hi0;
    unsigned int hi1;
    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
    return ret;
  }

  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
  static const unsigned long kPhiloxSA = 0xD2511F53;
  static const unsigned long kPhiloxSB = 0xCD9E8D57;
};

// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32,
                     x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);
}

} // namespace
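The header above only provides the counter-based generator; the dropout kernels elsewhere in this extension are what turn its output into keep/drop decisions. A minimal sketch of that usage, assuming a hypothetical kernel that writes a byte keep-mask (the kernel name, launch geometry, and mask encoding are illustrative, not taken from the deleted files):

// Illustrative only: each thread draws one uint4 from Philox, converts it to four uniform
// floats with uniform4(), and keeps an element when its draw is below the keep probability.
__global__ void philox_keep_mask(uint8_t* mask, size_t n, float keep_prob,
                                 unsigned long long seed, unsigned long long offset) {
  size_t idx = (size_t)(blockIdx.x * blockDim.x + threadIdx.x) * 4;  // 4 draws per thread
  if (idx >= n) return;
  Philox rng(seed, idx, offset);       // subsequence keyed by the element index
  float4 r = uniform4(rng());
  const float draws[4] = {r.x, r.y, r.z, r.w};
  for (int i = 0; i < 4 && idx + i < n; ++i)
    mask[idx + i] = draws[i] < keep_prob ? 1 : 0;
}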
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
deleted 100644 → 0 @ 2a4864d5
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
                                    torch::Tensor const& inputs,
                                    torch::Tensor const& input_weights,
                                    torch::Tensor const& output_weights,
                                    torch::Tensor const& input_biases,
                                    torch::Tensor const& output_biases,
                                    const half* pad_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
  void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  rocblas_int flags = 0;
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Input Linear Fwd
  input_lin_results.copy_(input_biases);
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(&beta_one),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
                        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
                        beta_zero,
                        static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches, flags);

  // Padded Softmax
  bool softmax_success = false;
  if (is_training) {
    softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
        reinterpret_cast<half*>(dropout_results_ptr),
        (is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
        reinterpret_cast<const half*>(bmm1_results_ptr), pad_mask,
        attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len,
        attn_batches * q_seq_len, attn_batches * q_seq_len / sequences,
        1.0f - dropout_prob, stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax<half, half, float>(
        reinterpret_cast<half*>(dropout_results_ptr), // this is actually softmax results, but
                                                      // making it consistent for the next function
        reinterpret_cast<const half*>(bmm1_results_ptr), pad_mask,
        k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);
  }

  // Matmul2
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta_zero,
                        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        attn_batches, flags);

  outputs.copy_(output_biases);
  // Output Linear
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(&beta_one),
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs};
}
std::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads,
                                    torch::Tensor const& matmul2_results,
                                    torch::Tensor const& dropout_results,
                                    torch::Tensor const& bmm1_results,
                                    torch::Tensor const& pad_mask,
                                    torch::Tensor const& input_lin_results,
                                    torch::Tensor const& inputs,
                                    torch::Tensor const& input_weights,
                                    torch::Tensor const& output_weights,
                                    torch::Tensor const& dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  rocblas_int flags = 0;
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        beta,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches, flags);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        v_lin_grads_ptr, lead_dim, batch_stride,
                        v_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches, flags);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<half* const>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const*>(bmm1_results.data_ptr()),
      reinterpret_cast<half const*>(pad_mask.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
                        k_lin_results_ptr, lead_dim, batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        q_lin_grads_ptr, lead_dim, batch_stride,
                        q_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches, flags);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
                        q_lin_results_ptr, lead_dim, batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        k_lin_grads_ptr, lead_dim, batch_stride,
                        k_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches, flags);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));

  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};
}

} // end namespace rocblas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
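One point worth making explicit before the next file: both functions above address the fused QKV buffer through the same lead_dim / batch_stride constants, where input_lin_results is laid out [q_seq_len, sequences, 3 * embed_dim] and the strided batched GEMMs reinterpret it as attn_batches = heads * sequences interleaved Q/K/V panels of width head_dim. A small worked example with assumed sizes (heads=16, sequences=8, embed_dim=1024; not taken from the deleted code):

// Illustrative only -- prints the derived geometry for hypothetical dimensions.
#include <cstdio>

int main() {
  const int heads = 16, sequences = 8, embed_dim = 1024;  // assumed values
  const int head_dim = embed_dim / heads;                 // 64
  const int attn_batches = heads * sequences;             // 128
  const int lead_dim = attn_batches * 3 * head_dim;       // 24576: stride between sequence positions
  const int batch_stride = 3 * head_dim;                  // 192: offset between per-head Q/K/V panels
  std::printf("head_dim=%d attn_batches=%d lead_dim=%d batch_stride=%d\n",
              head_dim, attn_batches, lead_dim, batch_stride);
  return 0;
}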
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
deleted 100644 → 0 @ 2a4864d5
#include <iostream>
#include <math.h>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace
multihead_attn
{
namespace
self_bias
{
namespace
rocblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
is_training
,
int
heads
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
input_weights
,
torch
::
Tensor
const
&
output_weights
,
torch
::
Tensor
const
&
input_biases
,
torch
::
Tensor
const
&
output_biases
,
const
uint8_t
*
pad_mask
,
float
dropout_prob
)
{
const
int
embed_dim
=
inputs
.
size
(
2
);
const
int
sequences
=
inputs
.
size
(
1
);
const
int
q_seq_len
=
inputs
.
size
(
0
);
const
int
k_seq_len
=
q_seq_len
;
const
int
batches
=
sequences
*
q_seq_len
;
const
int
head_dim
=
embed_dim
/
heads
;
const
int
output_lin_dim
=
3
*
embed_dim
;
const
int
attn_batches
=
heads
*
sequences
;
const
int
lead_dim
=
attn_batches
*
3
*
head_dim
;
const
int
batch_stride
=
3
*
head_dim
;
const
int
dropout_elems
=
attn_batches
*
q_seq_len
*
k_seq_len
;
const
float
alpha
=
1.0
;
const
float
beta_zero
=
0.0
;
const
float
beta_one
=
1.0
;
const
float
scale
=
1.0
/
sqrt
(
static_cast
<
float
>
(
head_dim
));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cublasSetStream
(
handle
,
stream
);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto
act_options
=
inputs
.
options
().
requires_grad
(
false
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
input_lin_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_dim
},
act_options
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
outputs
=
torch
::
empty_like
(
inputs
,
act_options
);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  rocblas_int flags = 0;

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Input Linear Fwd
  input_lin_results.copy_(input_biases);
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta_one),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
      static_cast<const half *>(k_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(q_lin_results_ptr), lead_dim, batch_stride,
      beta_zero,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast<const half *>(dropout_results.data_ptr())
                    : static_cast<const half *>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len,
      beta_zero,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      attn_batches, flags);

  outputs.copy_(output_biases);

  // Output Linear
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta_one),
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask,      matmul2_results, outputs};
}
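// For reference only: the GEMM -> softmax -> GEMM sequence above is batched
// scaled dot-product attention, out = softmax(scale * Q * K^T) * V per
// attention batch. The naive host-side helper below is a hypothetical sketch
// (not part of this file, not the rocBLAS path) that makes the computation
// concrete; it assumes <vector>, <cmath> and <limits> are available.
//
// void attention_reference(const float *q, const float *k, const float *v,
//                          float *out, int seq_len, int head_dim, float scale) {
//   std::vector<float> scores(seq_len);
//   for (int i = 0; i < seq_len; ++i) {
//     float max_s = -std::numeric_limits<float>::infinity();
//     for (int j = 0; j < seq_len; ++j) {
//       float s = 0.f;
//       for (int d = 0; d < head_dim; ++d)
//         s += q[i * head_dim + d] * k[j * head_dim + d];
//       scores[j] = s * scale;
//       if (scores[j] > max_s) max_s = scores[j];
//     }
//     float denom = 0.f;
//     for (int j = 0; j < seq_len; ++j) {
//       scores[j] = std::exp(scores[j] - max_s);
//       denom += scores[j];
//     }
//     for (int d = 0; d < head_dim; ++d) {
//       float acc = 0.f;
//       for (int j = 0; j < seq_len; ++j)
//         acc += (scores[j] / denom) * v[j * head_dim + d];
//       out[i * head_dim + d] = acc;
//     }
//   }
// }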
std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results, torch::Tensor const &input_lin_results,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  rocblas_int flags = 0;

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass()
              ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      beta,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<const half *>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      v_lin_grads_ptr, lead_dim, batch_stride,
      v_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
      k_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      q_lin_grads_ptr, lead_dim, batch_stride,
      q_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
      q_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      k_lin_grads_ptr, lead_dim, batch_stride,
      k_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
}

} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
\ No newline at end of file
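The pointer arithmetic used throughout these files (Q at offset 0, K at +head_dim, V at +2 * head_dim, with lead_dim = attn_batches * 3 * head_dim and batch_stride = 3 * head_dim) implies that the input projection writes Q, K, and V interleaved per attention batch (one head of one sequence). A minimal sketch of how an element is addressed under that layout; the helper name is hypothetical and shown only to make the strides concrete:

// Hypothetical index helper for the interleaved QKV buffer produced by the
// input projection: token t, attention batch b, matrix m (0 = Q, 1 = K, 2 = V),
// channel d within the head.
inline size_t qkv_index(int t, int b, int m, int d,
                        int attn_batches, int head_dim) {
  const size_t lead_dim = static_cast<size_t>(attn_batches) * 3 * head_dim;   // stride between tokens
  const size_t batch_stride = 3 * static_cast<size_t>(head_dim);              // stride between attention batches
  return static_cast<size_t>(t) * lead_dim + static_cast<size_t>(b) * batch_stride +
         static_cast<size_t>(m) * head_dim + static_cast<size_t>(d);
}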
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  deleted 100644 → 0
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  rocblas_int flags = 0;

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Input Linear Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
      static_cast<const half *>(k_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(q_lin_results_ptr), lead_dim, batch_stride,
      beta,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()),
        dropout_elems, (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast<const half *>(dropout_results.data_ptr())
                    : static_cast<const half *>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len,
      beta,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      attn_batches, flags);

  // Output Linear
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask,      matmul2_results, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results, torch::Tensor const &input_lin_results,
    torch::Tensor const &inputs, torch::Tensor const &input_weights,
    torch::Tensor const &output_weights, torch::Tensor const &dropout_mask,
    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  rocblas_int flags = 0;

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass()
              ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      beta,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<const half *>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      v_lin_grads_ptr, lead_dim, batch_stride,
      v_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(matmul2_grads.data_ptr()),
      static_cast<at::Half *>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      dropout_elems, (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
      k_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      q_lin_grads_ptr, lead_dim, batch_stride,
      q_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
      q_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      k_lin_grads_ptr, lead_dim, batch_stride,
      k_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads};
}

} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
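Both variants above use the same inverted-dropout convention: the forward kernel apex_fused_dropout_cuda is handed the keep probability (1 - dropout_prob) and records a byte mask, and the backward pass re-applies that saved mask while scaling by 1 / (1 - dropout_prob) via apex_masked_scale_cuda. A minimal element-wise sketch of that convention, assuming the usual inverted-dropout scaling at forward time (illustrative only; the real kernels in dropout.cuh are vectorized and RNG-driven, and these helper names are hypothetical):

// Reference inverted-dropout pair; keep_prob = 1 - dropout_prob,
// mask[i] is the saved uint8_t keep/drop flag.
template <typename T>
void dropout_forward_ref(const T *in, T *out, const uint8_t *mask,
                         size_t n, float keep_prob) {
  for (size_t i = 0; i < n; ++i)
    out[i] = mask[i] ? static_cast<T>(static_cast<float>(in[i]) / keep_prob) : T(0);
}

template <typename T>
void dropout_backward_ref(const T *grad_out, T *grad_in, const uint8_t *mask,
                          size_t n, float keep_prob) {
  for (size_t i = 0; i < n; ++i)
    grad_in[i] = mask[i] ? static_cast<T>(static_cast<float>(grad_out[i]) / keep_prob) : T(0);
}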
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  deleted 100644 → 0
#include <iostream>
#include <math.h>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self_norm_add {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                    int heads, torch::Tensor const &inputs,
                                    torch::Tensor const &lyr_nrm_gamma_weights,
                                    torch::Tensor const &lyr_nrm_beta_weights,
                                    torch::Tensor const &input_weights,
                                    torch::Tensor const &output_weights,
                                    const uint8_t *pad_mask,
                                    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);
  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
  torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(
      static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  rocblas_int flags = 0;

  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
      static_cast<float *>(lyr_nrm_mean.data_ptr()),
      static_cast<float *>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half *>(inputs.data_ptr()),
      static_cast<int>(batches),   // n1
      static_cast<int>(embed_dim), // n2
      1.0e-5,
      static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Fwd
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      //static_cast<const void*>(inputs.data_ptr()),
      static_cast<const void *>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
      static_cast<const void *>(&beta),
      q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/, output_lin_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
      static_cast<const half *>(k_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(q_lin_results_ptr), lead_dim, batch_stride,
      beta,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half *>(softmax_results_ptr),
        reinterpret_cast<const half *>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half *>(softmax_results_ptr),
          reinterpret_cast<const half *>(softmax_results_ptr),
          pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(softmax_results.data_ptr()),
        static_cast<at::Half *>(dropout_results.data_ptr()),
        static_cast<uint8_t *>(dropout_mask.data_ptr()),
        dropout_elems, (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast<const half *>(dropout_results.data_ptr())
                    : static_cast<const half *>(softmax_results.data_ptr()),
      //static_cast<const half*>(dropout_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len,
      beta,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<half *>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      attn_batches, flags);

  // Output Linear
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
      static_cast<void *>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()),
        static_cast<uint8_t *>(dropout_add_mask.data_ptr()),
        total_tokens, (1.0f - dropout_prob));
  } else {
    apex_add_cuda<at::Half, float, uint32_t>(
        static_cast<at::Half const *>(output_lin_results.data_ptr()),
        static_cast<at::Half const *>(inputs.data_ptr()),
        static_cast<at::Half *>(outputs.data_ptr()), total_tokens);
  }
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results, lyr_nrm_mean,   lyr_nrm_invvar,
          input_lin_results, softmax_results, dropout_results,
          dropout_mask,    matmul2_results, dropout_add_mask, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads, torch::Tensor const &output_grads,
    torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results,
    torch::Tensor const &softmax_results, torch::Tensor const &input_lin_results,
    torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean,
    torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs,
    torch::Tensor const &lyr_nrm_gamma_weights,
    torch::Tensor const &lyr_nrm_beta_weights,
    torch::Tensor const &input_weights, torch::Tensor const &output_weights,
    torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask,
    float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
  torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
  torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  torch::Tensor input_lin_grads = torch::empty_like(inputs);

  auto q_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half *>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  rocblas_int flags = 0;

  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef BACKWARD_PASS_GUARD
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass()
              ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

  // Dropout Add Backward
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(output_grads.data_ptr()),
      static_cast<at::Half *>(dropout_add_grads.data_ptr()),
      static_cast<uint8_t const *>(dropout_add_mask.data_ptr()),
      total_tokens, (1.0 / (1.0 - dropout_prob)));

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      static_cast<const void *>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
      static_cast<void *>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      static_cast<const void *>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
      static_cast<void *>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(
      a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
      static_cast<const half *>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      beta,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
      static_cast<const half *>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<const half *>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      v_lin_grads_ptr, lead_dim, batch_stride,
      v_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const *>(matmul2_grads.data_ptr()),
      static_cast<at::Half *>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const *>(dropout_mask.data_ptr()),
      dropout_elems, (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half *>(matmul2_grads.data_ptr()),
      static_cast<half *>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const *>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
      k_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      q_lin_grads_ptr, lead_dim, batch_stride,
      q_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
      q_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      beta,
      k_lin_grads_ptr, lead_dim, batch_stride,
      k_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast<const void *>(&alpha),
      static_cast<const void *>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
      static_cast<const void *>(&beta),
      //static_cast<void*>(input_grads.data_ptr()),
      static_cast<void *>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
      static_cast<void *>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast<const void *>(&alpha),
      //static_cast<const void*>(inputs.data_ptr()),
      static_cast<const void *>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
      static_cast<const void *>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
      static_cast<const void *>(&beta),
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
      static_cast<void *>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
      0 /*solution_index*/, flags));

  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient<half, float>(
      static_cast<const half *>(input_lin_grads.data_ptr()),
      static_cast<const half *>(output_grads.data_ptr()),
      static_cast<const float *>(lyr_nrm_mean.data_ptr()),
      static_cast<const float *>(lyr_nrm_invvar.data_ptr()),
      inputs,
      static_cast<int>(batches),   // n1
      static_cast<int>(embed_dim), // n2
      static_cast<const half *>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const half *>(lyr_nrm_beta_weights.data_ptr()),
      1.0e-5,
      static_cast<half *>(input_grads.data_ptr()),
      static_cast<half *>(lyr_nrm_gamma_grads.data_ptr()),
      static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
          input_weight_grads, output_weight_grads};
}

} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
\ No newline at end of file
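In the *_norm_add variant above, fwd_cuda normalizes the block input before the QKV projection and adds the original input back after the output projection and dropout, i.e. a pre-LayerNorm residual block. A sketch of the overall data flow, using only the names that appear in the file:

// Shape of the fused block implemented by self_norm_add::fwd_cuda (sketch only):
//   ln      = LayerNorm(inputs; gamma, beta)               // HostApplyLayerNorm
//   qkv     = InputLinear(ln)                              // rocblas_gemm_ex, interleaved Q/K/V
//   probs   = softmax(scale * Q K^T), optionally dropped   // MatMul1 + dispatch_*softmax + dropout
//   context = probs * V                                    // Matmul2
//   proj    = OutputLinear(context)                        // rocblas_gemm_ex
//   outputs = inputs + Dropout(proj)                       // apex_dropout_add_cuda / apex_add_cuda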
apex/contrib/csrc/multihead_attn/softmax.cuh  deleted 100644 → 0
#pragma once

#include "philox.cuh"
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <assert.h>
#include <cfloat>
#include <cmath>
#include <cuda_fp16.h>
#include <limits>
#include <stdint.h>

#ifdef __HIP_PLATFORM_HCC__
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
                                      const uint8_t *src);

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
                                               const Datatype *additive_mask);

template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
  *((half2 *)dst) = *((half2 *)src);
}

template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
                                                 const uint8_t *src) {
  if (*src == 1) {
    *dst = value;
  }
}

template <>
__device__ __inline__ void
apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
  *dst += *additive_mask;
}

template <>
__device__ __inline__ void
apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
  *dst += *additive_mask;
  *(dst + 1) += *(additive_mask + 1);
  *(dst + 2) += *(additive_mask + 2);
  *(dst + 3) += *(additive_mask + 3);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_forward(input_t *dst, const output_t *src,
                                     int batch_size, int stride,
                                     int element_count) {
  assert(ELEMENTS_PER_LDG_STG == 1);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
      }
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            &elements_input[i][it], src + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }
#pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_forward_func = void (*)(input_t *dst, const output_t *src,
                                      int batch_size, int stride,
                                      int element_count);

template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_kernel(int log2_elements, int &warp_size,
                         int &batches_per_warp,
                         softmax_forward_func<input_t, output_t> &kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
  case 0: // 1
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;
    break;
  case 1: // 2
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;
    break;
  case 2: // 4
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;
    break;
  case 3: // 8
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;
    break;
  case 4: // 16
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;
    break;
  case 5: // 32
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;
    break;
  case 6: // 64
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;
    break;
  case 7: // 128
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;
    break;
  case 8: // 256
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;
    break;
  case 9: // 512
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;
    break;
  case 10: // 1024
    kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;
    break;
  default:
    return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements,
                      int softmax_elements_stride, int batch_count) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements)
      ++log2_elements;

    softmax_forward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_kernel<input_t, output_t, acc_t>(
            log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        dst, src, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}
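// Usage sketch (assumed calling context, not part of this header): for a row
// length of 128, the dispatcher above rounds up to the next power of two
// (log2_elements = 7), selects softmax_warp_forward<..., 2, 4, 32, 1> (two rows
// per warp, four iterations of a 32-lane warp), and launches 128-thread blocks
// covering batch_count rows.
//
//   half *scores = ...;  // [batch_count, 128] attention scores, contiguous rows
//   bool ok = dispatch_softmax<half, half, float>(
//       scores, scores, /*softmax_elements=*/128,
//       /*softmax_elements_stride=*/128, batch_count);
//   assert(ok);  // only fails for row lengths greater than 1024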
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward_vec4(
    output_t *dst, uint8_t *dropout_mask, const input_t *src,
    const input_t *pad_mask, int batch_size, int stride, int element_count,
    int pad_batch_stride, at::PhiloxCudaState philox_args, float p) {
  assert(ELEMENTS_PER_LDG_STG == 4);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
            threadIdx.x;
  acc_t pinv = acc_t(1) / p;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // vectorize if element_count is multiple of 4, else don't vectorize
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;
  dropout_mask += thread_offset;

  // load data from global memory
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride +
                            ELEMENTS_PER_LDG_STG * local_idx;
    const half *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        // masking_value is a large negative value
        elements_input[i][it + element] = -10000;
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                   src + itr_idx);
        apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(
            &elements_input[i][it], curr_mask + itr_jmp);
        //(__half)-std::numeric_limits<float>::infinity()
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }
#pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  auto seeds = at::cuda::philox::unpack(philox_args);
  Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
  uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
  float4 rand_num;
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        rand_num = uniform4(ph());
        rands[i][it] = (rand_num.x <= p) > 0.5;
        rands[i][it + 1] = (rand_num.y <= p) > 0.5;
        rands[i][it + 2] = (rand_num.z <= p) > 0.5;
        rands[i][it + 3] = (rand_num.w <= p) > 0.5;
        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(
            dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
      }
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = rands[i][it + element] *
                         (pinv * (elements[i][it + element] / sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
__global__ void additive_masked_softmax_dropout_warp_forward
(
output_t
*
dst
,
uint8_t
*
dropout_mask
,
const
input_t
*
src
,
const
input_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
,
at
::
PhiloxCudaState
philox_args
,
float
p
)
{
assert
(
ELEMENTS_PER_LDG_STG
==
1
);
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
*
blockDim
.
y
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
acc_t
pinv
=
acc_t
(
1
)
/
p
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
;
// vectorize if element_count is multiple of 4, else don't vectorize
input_t
elements_input
[
WARP_BATCH
][
WARP_ITERATIONS
];
int
thread_offset
=
first_batch
*
stride
+
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
dropout_mask
+=
thread_offset
;
// load data from global memory
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int
pad_thread_offset
=
((
first_batch
+
i
)
/
pad_batch_stride
)
*
stride
+
local_idx
;
const
half
*
curr_mask
=
pad_mask
+
pad_thread_offset
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
1
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
#pragma unroll
for
(
int
element
=
0
;
element
<
1
;
++
element
)
{
// masking_value is a large negative value
elements_input
[
i
][
it
+
element
]
=
-
10000
;
}
if
(
element_index
<
batch_element_count
)
{
int
itr_jmp
=
it
*
WARP_SIZE
;
int
itr_idx
=
i
*
element_count
+
itr_jmp
;
copy_vector
<
input_t
,
1
>
(
&
elements_input
[
i
][
it
],
src
+
itr_idx
);
apply_additive_mask
<
input_t
,
1
>
(
&
elements_input
[
i
][
it
],
curr_mask
+
itr_jmp
);
}
}
}
// convert input_t to acc_t
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
elements_input
[
i
][
it
];
}
}
constexpr
uint32_t
FULL_MASK
=
0xffffffff
;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
}
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
// reduction max_value
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
float
val
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
val
[
i
]
=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
max_value
[
i
],
offset
,
WARP_SIZE
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
max_value
[
i
]
>
val
[
i
]
?
max_value
[
i
]
:
val
[
i
];
}
}
// compute local sum
acc_t
sum
[
WARP_BATCH
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
std
::
exp
(
elements
[
i
][
it
]
-
max_value
[
i
]);
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
// reduction sum
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
+=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
sum
[
i
],
offset
,
WARP_SIZE
);
}
}
curandStatePhilox4_32_10_t
state
;
auto
seeds
=
at
::
cuda
::
philox
::
unpack
(
philox_args
);
curand_init
(
std
::
get
<
0
>
(
seeds
),
tid
,
std
::
get
<
1
>
(
seeds
),
&
state
);
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
1
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
output_t
out
[
1
];
acc_t
softmax_out
[
1
];
uint8_t
dropout_mask_temp
[
1
];
// generate a vector of random numbers here
float
rand
=
curand_uniform
(
&
state
);
float
*
rand_ptr
=
(
float
*
)(
&
rand
);
#pragma unroll
for
(
int
element
=
0
;
element
<
1
;
++
element
)
{
softmax_out
[
element
]
=
(
elements
[
i
][
it
+
element
]
/
sum
[
i
]);
rand_ptr
[
element
]
=
rand_ptr
[
element
]
<=
p
;
out
[
element
]
=
rand_ptr
[
element
]
*
pinv
*
softmax_out
[
element
];
dropout_mask_temp
[
element
]
=
rand_ptr
[
element
]
>
0.5
;
// just to distinguish 0.0f and 1.0f
}
copy_vector
<
output_t
,
1
>
(
dst
+
i
*
element_count
+
it
*
WARP_SIZE
,
out
);
copy_vector
<
uint8_t
,
1
>
(
dropout_mask
+
i
*
element_count
+
it
*
WARP_SIZE
,
dropout_mask_temp
);
}
else
{
break
;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
using
additive_masked_softmax_dropout_forward_func
=
void
(
*
)(
output_t
*
dst
,
uint8_t
*
dropout_mask
,
const
input_t
*
src
,
const
input_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
,
at
::
PhiloxCudaState
philox_args
,
float
p
);
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
warp_additive_masked_softmax_dropout_kernel
(
int
element_count
,
int
log2_elements
,
int
&
warp_size
,
int
&
batches_per_warp
,
additive_masked_softmax_dropout_forward_func
<
input_t
,
output_t
,
acc_t
>
&
kernel
)
{
// determine size of a warp
const
int
next_power_of_two
=
1
<<
log2_elements
;
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
// determine how many batches a warp should process.
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
bool
flag_vec4
=
(
element_count
%
4
==
0
);
switch
(
log2_elements
)
{
case
0
:
// 1
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
1
,
1
>
;
break
;
case
1
:
// 2
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
2
,
1
>
;
break
;
case
2
:
// 4
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
4
,
1
>
;
break
;
case
3
:
// 8
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
8
,
1
>
;
break
;
case
4
:
// 16
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
16
,
1
>
;
break
;
case
5
:
// 32
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
32
,
1
>
;
break
;
case
6
:
// 64
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
2
,
32
,
1
>
;
break
;
case
7
:
// 128
if
(
flag_vec4
)
kernel
=
&
additive_masked_softmax_dropout_warp_forward_vec4
<
input_t
,
output_t
,
acc_t
,
2
,
4
,
32
,
4
>
;
else
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
4
,
32
,
1
>
;
break
;
case
8
:
// 256
if
(
flag_vec4
)
kernel
=
&
additive_masked_softmax_dropout_warp_forward_vec4
<
input_t
,
output_t
,
acc_t
,
1
,
8
,
32
,
4
>
;
else
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
8
,
32
,
1
>
;
break
;
case
9
:
// 512
if
(
flag_vec4
)
kernel
=
&
additive_masked_softmax_dropout_warp_forward_vec4
<
input_t
,
output_t
,
acc_t
,
1
,
16
,
32
,
4
>
;
else
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
16
,
32
,
1
>
;
break
;
case
10
:
// 1024
if
(
flag_vec4
)
kernel
=
&
additive_masked_softmax_dropout_warp_forward_vec4
<
input_t
,
output_t
,
acc_t
,
1
,
32
,
32
,
4
>
;
else
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
32
,
32
,
1
>
;
break
;
case
11
:
// 2048
if
(
flag_vec4
)
kernel
=
&
additive_masked_softmax_dropout_warp_forward_vec4
<
input_t
,
output_t
,
acc_t
,
1
,
64
,
32
,
4
>
;
else
kernel
=
&
additive_masked_softmax_dropout_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
64
,
32
,
1
>
;
break
;
default:
return
false
;
}
return
true
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
dispatch_additive_masked_softmax_dropout
(
output_t
*
dst
,
uint8_t
*
dropout_mask
,
const
input_t
*
src
,
const
input_t
*
pad_mask
,
int
totalElements
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
pad_batch_stride
,
float
p
,
cudaStream_t
streamid
)
// p is the probability to keep, not drop
{
if
(
softmax_elements
==
0
)
{
return
true
;
}
else
if
(
softmax_elements
<=
2048
)
{
// compute function index. there's a function for each power of two size up
// to 1024.
int
log2_elements
=
0
;
while
((
1
<<
log2_elements
)
<
softmax_elements
)
++
log2_elements
;
additive_masked_softmax_dropout_forward_func
<
input_t
,
output_t
,
acc_t
>
kernel
;
int
warp_size
,
batches_per_warp
;
if
(
!
warp_additive_masked_softmax_dropout_kernel
<
input_t
,
output_t
,
acc_t
>
(
softmax_elements
,
log2_elements
,
warp_size
,
batches_per_warp
,
kernel
))
{
return
false
;
}
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
// compute warps per block.
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
c10
::
optional
<
at
::
Generator
>
gen_
;
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
int64_t
counter_offset
=
(
totalElements
/
(
blocks
*
threads_per_block
)
+
1
);
at
::
PhiloxCudaState
rng_engine_inputs
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
->
mutex_
);
rng_engine_inputs
=
gen
->
philox_cuda_state
(
counter_offset
);
}
// compute launch size
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch
kernel
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
dst
,
dropout_mask
,
src
,
pad_mask
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
pad_batch_stride
,
rng_engine_inputs
,
p
);
return
true
;
}
return
false
;
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_ITERATIONS
,
int
WARP_SIZE
=
32
,
int
ELEMENTS_PER_LDG_STG
=
1
>
__global__
void
additive_masked_softmax_warp_forward
(
input_t
*
dst
,
const
output_t
*
src
,
const
input_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
)
{
assert
(
ELEMENTS_PER_LDG_STG
==
1
);
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
;
int
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
// load data from global memory
input_t
elements_input
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int
pad_thread_offset
=
((
first_batch
+
i
)
/
pad_batch_stride
)
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
const
half
*
curr_mask
=
pad_mask
+
pad_thread_offset
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
// masking_value is a large negative value
elements_input
[
i
][
it
+
element
]
=
-
10000
;
}
if
(
element_index
<
batch_element_count
)
{
int
itr_jmp
=
it
*
WARP_SIZE
;
int
itr_idx
=
i
*
element_count
+
itr_jmp
;
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
&
elements_input
[
i
][
it
],
src
+
itr_idx
);
// apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
// (__half)-std::numeric_limits<float>::infinity(),
// curr_mask + itr_jmp);
elements_input
[
i
][
it
]
+=
*
(
curr_mask
+
itr_jmp
);
}
}
}
// convert input_t to acc_t
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
elements_input
[
i
][
it
];
}
}
constexpr
uint32_t
FULL_MASK
=
0xffffffff
;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
}
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
// reduction max_value
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
float
val
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
val
[
i
]
=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
max_value
[
i
],
offset
,
WARP_SIZE
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
max_value
[
i
]
>
val
[
i
]
?
max_value
[
i
]
:
val
[
i
];
}
}
// compute local sum
acc_t
sum
[
WARP_BATCH
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements
[
i
][
it
]
=
std
::
exp
(
elements
[
i
][
it
]
-
max_value
[
i
]);
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
// reduction sum
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
+=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
sum
[
i
],
offset
,
WARP_SIZE
);
}
}
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t
out
[
ELEMENTS_PER_LDG_STG
];
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
out
[
element
]
=
elements
[
i
][
it
+
element
]
/
sum
[
i
];
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
+
it
*
WARP_SIZE
,
out
);
}
else
{
break
;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
>
using
additive_masked_softmax_forward_func
=
void
(
*
)(
input_t
*
dst
,
const
output_t
*
src
,
const
half
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
);
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
warp_additive_masked_softmax_kernel
(
int
log2_elements
,
int
&
warp_size
,
int
&
batches_per_warp
,
additive_masked_softmax_forward_func
<
input_t
,
output_t
>
&
kernel
)
{
// determine size of a warp
const
int
next_power_of_two
=
1
<<
log2_elements
;
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
// determine how many batches a warp should process.
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
switch
(
log2_elements
)
{
case
0
:
// 1
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
1
,
1
>
;
break
;
case
1
:
// 2
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
2
,
1
>
;
break
;
case
2
:
// 4
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
4
,
1
>
;
break
;
case
3
:
// 8
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
8
,
1
>
;
break
;
case
4
:
// 16
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
16
,
1
>
;
break
;
case
5
:
// 32
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
32
,
1
>
;
break
;
case
6
:
// 64
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
2
,
32
,
1
>
;
break
;
case
7
:
// 128
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
4
,
32
,
1
>
;
break
;
case
8
:
// 256
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
8
,
32
,
1
>
;
break
;
case
9
:
// 512
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
16
,
32
,
1
>
;
break
;
case
10
:
// 1024
kernel
=
&
additive_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
32
,
32
,
1
>
;
break
;
default:
return
false
;
}
return
true
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
dispatch_additive_masked_softmax
(
output_t
*
dst
,
const
input_t
*
src
,
const
input_t
*
pad_mask
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
pad_batch_stride
)
{
if
(
softmax_elements
==
0
)
{
return
true
;
}
else
if
(
softmax_elements
<=
1024
)
{
// compute function index. there's a function for each power of two size up
// to 1024.
int
log2_elements
=
0
;
while
((
1
<<
log2_elements
)
<
softmax_elements
)
++
log2_elements
;
additive_masked_softmax_forward_func
<
input_t
,
output_t
>
kernel
;
int
warp_size
,
batches_per_warp
;
if
(
!
warp_additive_masked_softmax_kernel
<
input_t
,
output_t
,
acc_t
>
(
log2_elements
,
warp_size
,
batches_per_warp
,
kernel
))
{
return
false
;
}
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
// compute warps per block.
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
// compute launch size
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch
kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
pad_mask
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
pad_batch_stride
);
return
true
;
}
return
false
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
dispatch_additive_masked_softmax_stream
(
output_t
*
dst
,
const
input_t
*
src
,
const
input_t
*
pad_mask
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
pad_batch_stride
,
cudaStream_t
streamid
)
{
if
(
softmax_elements
==
0
)
{
return
true
;
}
else
if
(
softmax_elements
<=
1024
)
{
// compute function index. there's a function for each power of two size up
// to 1024.
int
log2_elements
=
0
;
while
((
1
<<
log2_elements
)
<
softmax_elements
)
++
log2_elements
;
additive_masked_softmax_forward_func
<
input_t
,
output_t
>
kernel
;
int
warp_size
,
batches_per_warp
;
if
(
!
warp_additive_masked_softmax_kernel
<
input_t
,
output_t
,
acc_t
>
(
log2_elements
,
warp_size
,
batches_per_warp
,
kernel
))
{
return
false
;
}
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
// compute warps per block.
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
// compute launch size
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch
kernel
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
dst
,
src
,
pad_mask
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
pad_batch_stride
);
return
true
;
}
return
false
;
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_ITERATIONS
,
int
WARP_SIZE
=
32
,
int
ELEMENTS_PER_LDG_STG
=
1
>
__global__
void
masked_softmax_warp_forward
(
input_t
*
dst
,
const
output_t
*
src
,
const
uint8_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
)
{
assert
(
ELEMENTS_PER_LDG_STG
==
1
);
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
;
int
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
// load data from global memory
input_t
elements_input
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int
pad_thread_offset
=
((
first_batch
+
i
)
/
pad_batch_stride
)
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
const
uint8_t
*
curr_mask
=
pad_mask
+
pad_thread_offset
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
elements_input
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
if
(
element_index
<
batch_element_count
)
{
int
itr_jmp
=
it
*
WARP_SIZE
;
int
itr_idx
=
i
*
element_count
+
itr_jmp
;
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
&
elements_input
[
i
][
it
],
src
+
itr_idx
);
apply_mask
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
&
elements_input
[
i
][
it
],
(
__half
)
-
std
::
numeric_limits
<
float
>::
infinity
(),
curr_mask
+
itr_jmp
);
}
}
}
// convert input_t to acc_t
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
elements_input
[
i
][
it
];
}
}
constexpr
uint32_t
FULL_MASK
=
0xffffffff
;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
}
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
// reduction max_value
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
float
val
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
val
[
i
]
=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
max_value
[
i
],
offset
,
WARP_SIZE
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
max_value
[
i
]
>
val
[
i
]
?
max_value
[
i
]
:
val
[
i
];
}
}
// compute local sum
acc_t
sum
[
WARP_BATCH
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements
[
i
][
it
]
=
std
::
exp
(
elements
[
i
][
it
]
-
max_value
[
i
]);
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
// reduction sum
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
+=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
sum
[
i
],
offset
,
WARP_SIZE
);
}
}
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t
out
[
ELEMENTS_PER_LDG_STG
];
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
out
[
element
]
=
elements
[
i
][
it
+
element
]
/
sum
[
i
];
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
+
it
*
WARP_SIZE
,
out
);
}
else
{
break
;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
>
using
masked_softmax_forward_func
=
void
(
*
)(
input_t
*
dst
,
const
output_t
*
src
,
const
uint8_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
pad_batch_stride
);
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
warp_masked_softmax_kernel
(
int
log2_elements
,
int
&
warp_size
,
int
&
batches_per_warp
,
masked_softmax_forward_func
<
input_t
,
output_t
>
&
kernel
)
{
// determine size of a warp
const
int
next_power_of_two
=
1
<<
log2_elements
;
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
// determine how many batches a warp should process.
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
switch
(
log2_elements
)
{
case
0
:
// 1
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
1
,
1
>
;
break
;
case
1
:
// 2
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
2
,
1
>
;
break
;
case
2
:
// 4
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
4
,
1
>
;
break
;
case
3
:
// 8
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
8
,
1
>
;
break
;
case
4
:
// 16
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
16
,
1
>
;
break
;
case
5
:
// 32
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
32
,
1
>
;
break
;
case
6
:
// 64
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
2
,
32
,
1
>
;
break
;
case
7
:
// 128
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
4
,
32
,
1
>
;
break
;
case
8
:
// 256
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
8
,
32
,
1
>
;
break
;
case
9
:
// 512
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
16
,
32
,
1
>
;
break
;
case
10
:
// 1024
kernel
=
&
masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
32
,
32
,
1
>
;
break
;
default:
return
false
;
}
return
true
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
dispatch_masked_softmax
(
output_t
*
dst
,
const
input_t
*
src
,
const
uint8_t
*
pad_mask
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
pad_batch_stride
)
{
if
(
softmax_elements
==
0
)
{
return
true
;
}
else
if
(
softmax_elements
<=
1024
)
{
// compute function index. there's a function for each power of two size up
// to 1024.
int
log2_elements
=
0
;
while
((
1
<<
log2_elements
)
<
softmax_elements
)
++
log2_elements
;
masked_softmax_forward_func
<
input_t
,
output_t
>
kernel
;
int
warp_size
,
batches_per_warp
;
if
(
!
warp_masked_softmax_kernel
<
input_t
,
output_t
,
acc_t
>
(
log2_elements
,
warp_size
,
batches_per_warp
,
kernel
))
{
return
false
;
}
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
// compute warps per block.
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
// compute launch size
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch
kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
pad_mask
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
pad_batch_stride
);
return
true
;
}
return
false
;
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_ITERATIONS
,
int
WARP_SIZE
=
32
,
int
ELEMENTS_PER_LDG_STG
=
1
>
__global__
void
time_masked_softmax_warp_forward
(
input_t
*
dst
,
const
output_t
*
src
,
const
uint8_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
mod_seq_len
)
{
assert
(
ELEMENTS_PER_LDG_STG
==
1
);
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
;
int
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
src
+=
thread_offset
;
dst
+=
thread_offset
;
// load data from global memory
input_t
elements_input
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
int
pad_thread_offset
=
((
first_batch
+
i
)
%
mod_seq_len
)
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
const
uint8_t
*
curr_mask
=
pad_mask
+
pad_thread_offset
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
elements_input
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
if
(
element_index
<
batch_element_count
)
{
int
itr_jmp
=
it
*
WARP_SIZE
;
int
itr_idx
=
i
*
element_count
+
itr_jmp
;
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
&
elements_input
[
i
][
it
],
src
+
itr_idx
);
apply_mask
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
&
elements_input
[
i
][
it
],
(
__half
)
-
std
::
numeric_limits
<
float
>::
infinity
(),
curr_mask
+
itr_jmp
);
}
}
}
// convert input_t to acc_t
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
elements
[
i
][
it
]
=
elements_input
[
i
][
it
];
}
}
constexpr
uint32_t
FULL_MASK
=
0xffffffff
;
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
}
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
// reduction max_value
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
float
val
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
val
[
i
]
=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
max_value
[
i
],
offset
,
WARP_SIZE
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
max_value
[
i
]
>
val
[
i
]
?
max_value
[
i
]
:
val
[
i
];
}
}
// compute local sum
acc_t
sum
[
WARP_BATCH
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
// elements[i][it] = expf(elements[i][it] - max_value[i]);
elements
[
i
][
it
]
=
std
::
exp
(
elements
[
i
][
it
]
-
max_value
[
i
]);
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
// reduction sum
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
+=
APEX_WARP_SHFL_XOR
(
FULL_MASK
,
sum
[
i
],
offset
,
WARP_SIZE
);
}
}
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
output_t
out
[
ELEMENTS_PER_LDG_STG
];
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
out
[
element
]
=
elements
[
i
][
it
+
element
]
/
sum
[
i
];
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
+
it
*
WARP_SIZE
,
out
);
}
else
{
break
;
}
}
}
}
// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template
<
typename
input_t
,
typename
output_t
>
using
time_masked_softmax_forward_func
=
void
(
*
)(
input_t
*
dst
,
const
output_t
*
src
,
const
uint8_t
*
pad_mask
,
int
batch_size
,
int
stride
,
int
element_count
,
int
mod_seq_len
);
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
warp_time_masked_softmax_kernel
(
int
log2_elements
,
int
&
warp_size
,
int
&
batches_per_warp
,
time_masked_softmax_forward_func
<
input_t
,
output_t
>
&
kernel
)
{
// determine size of a warp
const
int
next_power_of_two
=
1
<<
log2_elements
;
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
// determine how many batches a warp should process.
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
switch
(
log2_elements
)
{
case
0
:
// 1
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
1
,
1
>
;
break
;
case
1
:
// 2
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
2
,
1
>
;
break
;
case
2
:
// 4
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
4
,
1
>
;
break
;
case
3
:
// 8
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
8
,
1
>
;
break
;
case
4
:
// 16
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
16
,
1
>
;
break
;
case
5
:
// 32
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
1
,
32
,
1
>
;
break
;
case
6
:
// 64
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
2
,
32
,
1
>
;
break
;
case
7
:
// 128
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
,
4
,
32
,
1
>
;
break
;
case
8
:
// 256
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
8
,
32
,
1
>
;
break
;
case
9
:
// 512
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
16
,
32
,
1
>
;
break
;
case
10
:
// 1024
kernel
=
&
time_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
,
32
,
32
,
1
>
;
break
;
default:
return
false
;
}
return
true
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
bool
dispatch_time_masked_softmax
(
output_t
*
dst
,
const
input_t
*
src
,
const
uint8_t
*
pad_mask
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
mod_seq_len
)
{
if
(
softmax_elements
==
0
)
{
return
true
;
}
else
if
(
softmax_elements
<=
1024
)
{
// compute function index. there's a function for each power of two size up
// to 1024.
int
log2_elements
=
0
;
while
((
1
<<
log2_elements
)
<
softmax_elements
)
++
log2_elements
;
time_masked_softmax_forward_func
<
input_t
,
output_t
>
kernel
;
int
warp_size
,
batches_per_warp
;
if
(
!
warp_time_masked_softmax_kernel
<
input_t
,
output_t
,
acc_t
>
(
log2_elements
,
warp_size
,
batches_per_warp
,
kernel
))
{
return
false
;
}
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
// compute warps per block.
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
// compute launch size
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch
kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
pad_mask
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
mod_seq_len
);
return
true
;
}
return
false
;
}
int
log2_ceil_native
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
>
__device__
__forceinline__
T
WARP_SHFL_XOR_NATIVE
(
T
value
,
int
laneMask
,
int
width
=
warpSize
,
unsigned
int
mask
=
0xffffffff
)
{
#if CUDA_VERSION >= 9000
return
__shfl_xor_sync
(
mask
,
value
,
laneMask
,
width
);
#else
return
__shfl_xor
(
value
,
laneMask
,
width
);
#endif
}
template
<
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_SIZE
>
__device__
__forceinline__
void
warp_reduce_sum
(
acc_t
*
sum
)
{
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
acc_t
b
=
WARP_SHFL_XOR_NATIVE
(
sum
[
i
],
offset
,
WARP_SIZE
);
sum
[
i
]
=
sum
[
i
]
+
b
;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward functions as fused variants of
// at::softmax_backward_data function
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// softmax backward data function is taken from native pytorch, elementwise mul
// is fused in the epolog, as well as masking and scaling for fusing dropout
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
,
bool
is_log_softmax
>
__global__
void
masked_scale_softmax_warp_backward_masked_dgrad
(
output_t
*
gradInput
,
const
input_t
*
grad
,
const
input_t
*
output
,
const
uint8_t
*
mask
,
const
uint8_t
*
pad_mask
,
acc_t
scale
,
int
batch_size
,
int
stride
,
int
element_count
,
int
heads
)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr
int
next_power_of_two
=
1
<<
log2_elements
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
x
+
threadIdx
.
y
)
*
WARP_BATCH
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
%
WARP_SIZE
;
// the first element to process by the current thread
int
thread_offset
=
first_batch
*
stride
+
local_idx
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
mask
+=
thread_offset
;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
// to one loop, but I think doing so would obfuscate the logic of the
// algorithm, thus I chose to keep the nested loops. This should have no
// impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t
grad_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
acc_t
output_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
element_count
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
grad_reg
[
i
][
it
]
=
(
input_t
)((
acc_t
)
mask
[
i
*
element_count
+
it
*
WARP_SIZE
]
*
(
acc_t
)
grad
[
i
*
element_count
+
it
*
WARP_SIZE
]
*
(
acc_t
)
scale
)
*
output
[
i
*
element_count
+
it
*
WARP_SIZE
];
output_reg
[
i
][
it
]
=
output
[
i
*
element_count
+
it
*
WARP_SIZE
];
}
else
{
grad_reg
[
i
][
it
]
=
acc_t
(
0
);
output_reg
[
i
][
it
]
=
acc_t
(
0
);
}
}
}
acc_t
sum
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
=
grad_reg
[
i
][
0
];
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
sum
[
i
]
+=
grad_reg
[
i
][
it
];
}
}
warp_reduce_sum
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
>
(
sum
);
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// compute gradients
int
total_ind
=
thread_offset
+
i
*
element_count
+
it
*
WARP_SIZE
;
int
pad_mask_ind
=
element_count
*
(
total_ind
/
(
heads
*
element_count
*
element_count
))
+
total_ind
%
element_count
;
uint8_t
pad_mask_element
=
1
-
pad_mask
[
pad_mask_ind
];
if
(
pad_mask_element
==
0
)
gradInput
[
i
*
element_count
+
it
*
WARP_SIZE
]
=
0
;
else
{
if
(
is_log_softmax
)
{
gradInput
[
i
*
element_count
+
it
*
WARP_SIZE
]
=
(
grad_reg
[
i
][
it
]
-
std
::
exp
(
output_reg
[
i
][
it
])
*
sum
[
i
]);
}
else
{
gradInput
[
i
*
element_count
+
it
*
WARP_SIZE
]
=
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]);
}
}
}
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
bool
is_log_softmax
>
void
dispatch_masked_scale_softmax_backward_masked_out
(
output_t
*
grad_input
,
const
input_t
*
grad
,
const
input_t
*
output
,
const
uint8_t
*
mask
,
const
uint8_t
*
pad_mask
,
acc_t
scale
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
heads
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
1024
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
int
log2_elements
=
log2_ceil_native
(
softmax_elements
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
// 1
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
0
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
1
:
// 2
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
1
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
2
:
// 4
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
2
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
3
:
// 8
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
3
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
4
:
// 16
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
4
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
5
:
// 32
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
5
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
6
:
// 64
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
6
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
7
:
// 128
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
7
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
8
:
// 256
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
8
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
9
:
// 512
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
9
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
10
:
// 1024
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
10
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
default:
break
;
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
bool
is_log_softmax
>
void
dispatch_masked_scale_softmax_backward_masked_out_stream
(
output_t
*
grad_input
,
const
input_t
*
grad
,
const
input_t
*
output
,
const
uint8_t
*
mask
,
const
uint8_t
*
pad_mask
,
acc_t
scale
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
batch_count
,
int
heads
,
cudaStream_t
streamid
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
1024
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
int
log2_elements
=
log2_ceil_native
(
softmax_elements
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
batch_count
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
// 1
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
0
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
1
:
// 2
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
1
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
2
:
// 4
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
2
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
3
:
// 8
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
3
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
4
:
// 16
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
4
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
5
:
// 32
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
5
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
6
:
// 64
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
6
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
7
:
// 128
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
7
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
8
:
// 256
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
8
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
9
:
// 512
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
9
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
case
10
:
// 1024
masked_scale_softmax_warp_backward_masked_dgrad
<
input_t
,
output_t
,
acc_t
,
10
,
is_log_softmax
>
<<<
blocks
,
threads
,
0
,
streamid
>>>
(
grad_input
,
grad
,
output
,
mask
,
pad_mask
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
,
heads
);
break
;
default:
break
;
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements,
          bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad,
                                                   const input_t *output, const uint8_t *mask,
                                                   acc_t scale, int batch_size, int stride,
                                                   int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;
  mask += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
                                    (acc_t)grad[i * element_count + it * WARP_SIZE] *
                                    (acc_t)scale) *
                          output[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
      }
    }
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        if (is_log_softmax) {
          gradInput[i * element_count + it * WARP_SIZE] =
              (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
        } else {
          gradInput[i * element_count + it * WARP_SIZE] =
              (grad_reg[i][it] - output_reg[i][it] * sum[i]);
        }
      }
    }
  }
}
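// Backward identity realized by the kernel above (plain softmax case): for
// y = softmax(x) and an upstream gradient g, dx_j = y_j * (g_j - sum_k y_k * g_k).
// Here g is the incoming gradient already multiplied by the dropout keep-mask
// and the rescale factor `scale`; grad_reg folds the y_j * g_j product in up
// front, so the warp reduction over grad_reg yields sum_k y_k * g_k directly.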
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG, bool is_log_softmax>
__global__ void masked_scale_softmax_warp_backward_recompute(
    output_t *gradInput, const input_t *grad, const input_t *softmax_input,
    const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, int stride,
    int pad_batch_stride, int element_count) {
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // vectorize if a row length is a multiple of 4
  int flag_vec4 = (element_count & 3) == 0;

  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  softmax_input += thread_offset;
  gradInput += thread_offset;
  mask += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset =
        ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const input_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        // masking_value is a large negative value
        elements_input[i][it + element] = -10000;
        grad_reg[i][it + element] = acc_t(0);
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                   softmax_input + itr_idx);
        apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                           curr_mask + itr_jmp);
        //(__half)-std::numeric_limits<float>::infinity()
        uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
        input_t grad_temp[ELEMENTS_PER_LDG_STG];
        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          grad_reg[i][it + element] =
              ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale);
        }
      }
    }
  }

  // load data from global memory
  // convert input_t to acc_t
  // TODO : remove this, input is already acc_t type in register
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }
#pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it++) {
      elements[i][it] = elements[i][it] / sum[i];
      grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
    }
  }

  acc_t grad_sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    grad_sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      grad_sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(grad_sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) {
          if (is_log_softmax) {
            grad_input_reg[element] =
                (grad_reg[i][it + element] - std::exp(elements[i][it + element]) * grad_sum[i]);
          } else {
            grad_input_reg[element] =
                (grad_reg[i][it + element] - elements[i][it + element] * grad_sum[i]);
          }
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE,
                                                    grad_input_reg);
      }
    }
  }
}
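// Note on the _recompute variant above: instead of reading stored softmax
// probabilities, it reloads the pre-softmax logits (softmax_input), re-applies
// the additive padding mask, and redoes the max/exp/sum normalization inside
// the backward kernel. That trades a little extra math for not keeping the
// full attention-probability tensor alive between forward and backward.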
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
using masked_scale_softmax_warp_backward_recompute_func =
    void (*)(output_t *gradInput, const input_t *grad, const input_t *softmax_input,
             const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size,
             int stride, int pad_batch_stride, int element_count);
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool masked_scale_softmax_warp_backward_recompute_kernel(
    int element_count, int log2_elements, int &warp_size, int &batches_per_warp,
    masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax>
        &kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  bool flag_vec4 = (element_count % 4 == 0);
  switch (log2_elements) {
  case 0: // 1
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>;
    break;
  case 1: // 2
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>;
    break;
  case 2: // 4
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>;
    break;
  case 3: // 8
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>;
    break;
  case 4: // 16
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>;
    break;
  case 5: // 32
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>;
    break;
  case 6: // 64
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>;
    break;
  case 7: // 128
    kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>;
    break;
  case 8: // 256
    if (flag_vec4)
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>;
    else
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>;
    break;
  case 9: // 512
    if (flag_vec4)
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>;
    else
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>;
    break;
  case 10: // 1024
    if (flag_vec4)
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>;
    else
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>;
    break;
  case 11: // 2048
    if (flag_vec4)
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>;
    else
      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>;
    break;
  default:
    return false;
  }
  return true;
}
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
bool dispatch_masked_scale_softmax_backward_recompute(
    output_t *grad_input, const input_t *grad, const input_t *softmax_input,
    const input_t *pad_mask, const uint8_t *mask, acc_t scale, int softmax_elements,
    int softmax_elements_stride, int pad_batch_stride, int batch_count, cudaStream_t streamid) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 2048) {
    // compute function index. there's a function for each power of two size up to 2048.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements)
      ++log2_elements;

    masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax>
        kernel;
    int warp_size, batches_per_warp;
    if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t,
                                                             is_log_softmax>(
            softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;

    // compute launch size
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask,
                                             scale, batch_count, softmax_elements_stride,
                                             pad_batch_stride, softmax_elements);
    return true;
  }
  return false;
}
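// A minimal usage sketch for the recompute dispatcher above, assuming an
// attention-score layout of [batches * heads * q_len, k_len] with one padding
// mask row shared by the heads * q_len score rows of a sequence. The function
// and variable names here are hypothetical illustrations, not taken from the
// apex frontend sources; `keep_scale` is the dropout rescale factor 1/(1-p).
inline bool example_masked_scale_softmax_bwd_recompute(
    __half *grad_input, const __half *grad_output, const __half *attn_logits,
    const __half *pad_mask, const uint8_t *dropout_mask, float keep_scale, int k_len,
    int batches, int heads, int q_len, cudaStream_t stream) {
  const int rows = batches * heads * q_len;
  // Returns false when k_len exceeds the 2048-element limit of the warp kernel.
  return dispatch_masked_scale_softmax_backward_recompute<__half, __half, float, false>(
      grad_input, grad_output, attn_logits, pad_mask, dropout_mask, keep_scale,
      /*softmax_elements=*/k_len, /*softmax_elements_stride=*/k_len,
      /*pad_batch_stride=*/heads * q_len, /*batch_count=*/rows, stream);
}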
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_masked_scale_softmax_backward_stream(output_t *grad_input, const input_t *grad,
                                                   const input_t *output, const uint8_t *mask,
                                                   acc_t scale, int softmax_elements,
                                                   int softmax_elements_stride, int batch_count,
                                                   cudaStream_t streamid) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
    case 0: // 1
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 1: // 2
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 2: // 4
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 3: // 8
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 4: // 16
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 5: // 32
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 6: // 64
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 7: // 128
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 8: // 256
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 9: // 512
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    case 10: // 1024
      masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>
          <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,
                                             softmax_elements_stride, softmax_elements);
      break;
    default:
      break;
    }
  }
}
// The elementwise multiplication called in at::softmax_backward_data is fused
// inside the softmax dgrad kernel. As a result of the fusion, the intermediate
// multiplication result is stored in fp32 in registers, instead of fp16.
template <typename input_t, typename output_t, typename acc_t, int log2_elements,
          bool is_log_softmax>
__global__ void softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad,
                                                   const input_t *output, int batch_size,
                                                   int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] *
                          output[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
      }
    }
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0]; //* output_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it]; // * output_reg[i][it];
    }
  }
  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        if (is_log_softmax) {
          gradInput[i * element_count + it * WARP_SIZE] =
              (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
        } else {
          gradInput[i * element_count + it * WARP_SIZE] =
              (grad_reg[i][it] - output_reg[i][it] * sum[i]);
        }
      }
    }
  }
}
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_backward_fused_native(output_t *grad_input, const input_t *grad,
                                            const input_t *output, int softmax_elements,
                                            int softmax_elements_stride, int batch_count) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
    case 0: // 1
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 1: // 2
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 2: // 4
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 3: // 8
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 4: // 16
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 5: // 32
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 6: // 64
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 7: // 128
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 8: // 256
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 9: // 512
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    case 10: // 1024
      softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
      break;
    default:
      break;
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output,
                                      int batch_size, int stride, int element_count) {
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it],
                                                   grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
                                                   output + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert half to floating point
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      grad_reg[i][it] = grad_reg_input[i][it];
      output_reg[i][it] = output_reg_input[i][it];
    }
  }

  // compute thread local sum
  acc_t sum[WARP_BATCH] = {0};
#pragma unroll
  for (int it = 0; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += grad_reg[i][it] * output_reg[i][it];
    }
  }

  // reduction sum
  constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));
        }
        // store them in global memory
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE,
                                                    out);
      }
    }
  }
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad,
                                       const input_t *output, int batch_size, int stride,
                                       int element_count);
template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                                  softmax_backward_func<input_t, output_t> &kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
  case 0: // 1
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
    break;
  case 1: // 2
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
    break;
  case 2: // 4
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
    break;
  case 3: // 8
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
    break;
  case 4: // 16
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
    break;
  case 5: // 32
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
    break;
  case 6: // 64
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
    break;
  case 7: // 128
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
    break;
  case 8: // 256
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
    break;
  case 9: // 512
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
    break;
  case 10: // 1024
    kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
    break;
  default:
    return false;
  }
  return true;
}
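// Selection logic above, spelled out: with next_power_of_two = 1 << log2_elements,
// the chosen template arguments always satisfy WARP_ITERATIONS * WARP_SIZE ==
// next_power_of_two (e.g. case 7 -> 4 * 32 = 128). The warp shrinks below 32
// threads only for rows shorter than a warp, and WARP_BATCH drops from 2 to 1
// once rows exceed 128 elements so per-thread register usage stays bounded.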
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
                               int softmax_elements, int softmax_elements_stride,
                               int batch_count) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements)
      ++log2_elements;

    softmax_backward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size,
                                                                batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}
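// A minimal usage sketch for the dispatcher above, assuming attention scores
// flattened to [batches * heads * q_len, k_len]; each row of length k_len is
// one softmax, and the row count becomes batch_count. The function and
// variable names here are hypothetical illustrations, not taken from the apex
// frontend sources.
inline bool example_attn_softmax_backward(__half *grad_input, const __half *grad_output,
                                          const __half *softmax_output, int k_len, int batches,
                                          int heads, int q_len) {
  // Returns false when k_len exceeds the 1024-element limit of the warp kernel.
  return dispatch_softmax_backward<__half, __half, float>(
      grad_input, grad_output, softmax_output,
      /*softmax_elements=*/k_len, /*softmax_elements_stride=*/k_len,
      /*batch_count=*/batches * heads * q_len);
}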
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad,
                                      const input_t *output, int softmax_elements,
                                      int softmax_elements_stride, int batch_count,
                                      cudaStream_t streamid) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements)
      ++log2_elements;

    softmax_backward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size,
                                                                batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count,
                                             softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}
template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH,
          int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad,
                                             const __half *output, const uint8_t *pad_mask,
                                             int batch_size, int stride, int element_count,
                                             int pad_batch_stride) {
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it],
                                                   grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
                                                   output + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert half to floating point
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      grad_reg[i][it] = grad_reg_input[i][it];
      output_reg[i][it] = output_reg_input[i][it];
    }
  }

  // compute thread local sum
  acc_t sum[WARP_BATCH] = {0};
#pragma unroll
  for (int it = 0; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += grad_reg[i][it] * output_reg[i][it];
    }
  }

  // reduction sum
  constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
    int pad_thread_offset =
        ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t *curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));
        }
        // store them in global memory
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        // It is kind of unfortunate this has to be here to zero something out
        // that is close to zero in the first place
        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);
      }
    }
  }
}
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad,
                                              const input_t *output, const uint8_t *pad_mask,
                                              int batch_size, int stride, int element_count,
                                              int pad_batch_stride);
template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size,
                                         int &batches_per_warp,
                                         masked_softmax_backward_func<input_t, output_t> &kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
  case 0: // 1
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;
    break;
  case 1: // 2
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;
    break;
  case 2: // 4
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;
    break;
  case 3: // 8
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;
    break;
  case 4: // 16
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;
    break;
  case 5: // 32
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;
    break;
  case 6: // 64
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;
    break;
  case 7: // 128
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;
    break;
  case 8: // 256
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;
    break;
  case 9: // 512
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;
    break;
  case 10: // 1024
    kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;
    break;
  default:
    return false;
  }
  return true;
}
template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
                                      const input_t *output, const uint8_t *pad_mask,
                                      int softmax_elements, int softmax_elements_stride,
                                      int batch_count, int pad_batch_stride) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements)
      ++log2_elements;

    masked_softmax_backward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(
            log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride,
        softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}
} // namespace