gaoqiong / flash-attention · Commits

Commit ae137ed1, authored Dec 10, 2022 by Tri Dao

    [LayerNorm] Fuse LayerScale

parent 8c6609ae

Showing 8 changed files with 312 additions and 330 deletions (+312 −330)
csrc/layer_norm/ln.h                   +8   −1
csrc/layer_norm/ln_api.cpp             +58  −133
csrc/layer_norm/ln_bwd_kernels.cuh     +118 −57
csrc/layer_norm/ln_fwd_kernels.cuh     +45  −44
csrc/layer_norm/ln_kernel_traits.h     +3   −1
csrc/layer_norm/ln_utils.cuh           +2   −2
flash_attn/ops/layer_norm.py           +47  −80
tests/ops/test_dropout_layer_norm.py   +31  −12
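For orientation, the fused op with LayerScale (the per-column colscale this commit adds) computes roughly the following. This is a PyTorch sketch for reference only; the helper name and exact argument handling are illustrative, not the extension's API.

    import torch
    import torch.nn.functional as F

    def dropout_add_ln_reference(x0, x1, gamma, beta, rowscale, colscale,
                                 dropout_p, epsilon):
        # Illustrative shapes: x0/x1 are (rows, cols), rowscale is (rows,),
        # colscale/gamma/beta are (cols,).
        x0s = x0 * rowscale[:, None] if rowscale is not None else x0
        x0d = F.dropout(x0s, p=dropout_p)
        if colscale is not None:        # LayerScale, fused by this commit
            x0d = x0d * colscale
        x = x0d + x1 if x1 is not None else x0d      # pre-norm residual
        z = F.layer_norm(x, (x.shape[-1],), gamma, beta, epsilon)
        return z, x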
csrc/layer_norm/ln.h  (+8 −1)

@@ -40,6 +40,8 @@ struct ParamsBase {
         , mu(nullptr)
         , rs(nullptr)
         , gamma(nullptr)
         , rowscale(nullptr)
+        , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
         , workspace(nullptr)
@@ -63,6 +65,7 @@ struct ParamsBase {
     void *rs;
     void *gamma;
     void *rowscale;
+    void *colscale;
     float inverse_cols;
@@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase {
         , dx(nullptr)
         , dbeta_part(nullptr)
         , dgamma_part(nullptr)
+        , dcolscale_part(nullptr)
         , dx0(nullptr)
         , dx1(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
+        , dcolscale(nullptr)
     {
     }
@@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase {
     // Workspace for Wgrad pre-reduction.
     void *dbeta_part;
     void *dgamma_part;
+    void *dcolscale_part;

     // Output: Dgrad.
     void *dx0;
@@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase {
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
+    void *dcolscale;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
-using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
+using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
 using FunctionKey = uint64_t;
 using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
 using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
csrc/layer_norm/ln_api.cpp  (+58 −133)

@@ -84,6 +84,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
                                            const at::Tensor &gamma,   // hidden_size
                                            const at::Tensor &beta,    // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // BxS
                                            const float dropout_p,
                                            const float epsilon,
                                            c10::optional<at::Generator> gen_,
@@ -124,7 +125,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
     }
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+    }

     TORCH_CHECK(gamma.sizes() == beta.sizes());
@@ -135,7 +144,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
@@ -153,6 +162,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -212,12 +222,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dx_,     // BxSxhidden_size
                                            const at::Tensor &x,      // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &x0_,     // BxSxhidden_size
                                            c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
                                            const at::Tensor &mu,     // BxS, FP32!
                                            const at::Tensor &rsigma, // BxS, FP32!
                                            const at::Tensor &gamma,  // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // BxS
                                            const float dropout_p,
                                            const bool has_residual
 ) {
@@ -250,133 +263,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     auto rows = sizes[0];
     auto cols = sizes[1];

-    if (dmask_.has_value()) {
-        auto dmask = dmask_.value();
-        TORCH_CHECK(dmask.dtype() == mtype);
-        TORCH_CHECK(dmask.is_cuda());
-        TORCH_CHECK(dmask.is_contiguous());
-        TORCH_CHECK(dmask.sizes() == sizes);
-    }
-    if (rowscale_.has_value()) {
-        auto rowscale = rowscale_.value();
-        TORCH_CHECK(rowscale.is_cuda())
-        TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
-    }
-
-    auto hidden_size = gamma.numel();
-    TORCH_CHECK(hidden_size == cols);
-    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
-    TORCH_CHECK(mu.numel() == rows);
-    TORCH_CHECK(mu.sizes() == rsigma.sizes());
-    TORCH_CHECK(gamma.numel() == cols);
-
-    auto opts = x.options();
-
-    auto dx0 = torch::empty_like(x, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
-    auto dgamma = torch::empty_like(gamma);
-    auto dbeta = torch::empty_like(gamma);
-
-    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
-    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
-    launch_params.props = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dropout_p < 1.f);
-    launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
-    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
-
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
-    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
-
-    launcher(launch_params, true, /*prenorm=*/false);
-
-    auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, opts.dtype(ctype));
-    auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, opts.dtype(ctype));
-    at::Tensor workspace, barrier;
-
-    layer_norm::BwdParams &params = launch_params.params;
-    params.rows = rows;
-    params.cols = cols;
-    params.x = x.data_ptr();
-    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
-    params.mu = mu.data_ptr();
-    params.rs = rsigma.data_ptr();
-    params.gamma = gamma.data_ptr();
-    params.dz = dz.data_ptr();
-    params.dx0 = dx0.data_ptr();
-    params.dbeta = dbeta.data_ptr();
-    params.dgamma = dgamma.data_ptr();
-    params.dbeta_part = dbeta_part.data_ptr();
-    params.dgamma_part = dgamma_part.data_ptr();
-    params.dropout_scale = 1.f / (1.f - dropout_p);
-    params.inverse_cols = 1.f / float(params.cols);
-
-    if (launch_params.barrier_size > 0) {
-        // TODO Any way to avoid this?
-        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
-        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
-        params.workspace = workspace.data_ptr();
-        params.barrier = barrier.data_ptr<int>();
-    }
-
-    launcher(launch_params, false, /*prenorm=*/false);
-
-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz,     // BxSxhidden_size
-                                                   const at::Tensor &dx,     // BxSxhidden_size
-                                                   const at::Tensor &x,      // BxSxhidden_size
-                                                   c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
-                                                   const at::Tensor &mu,     // BxS, FP32!
-                                                   const at::Tensor &rsigma, // BxS, FP32!
-                                                   const at::Tensor &gamma,  // hidden_size
-                                                   c10::optional<const at::Tensor> &rowscale_,  // BxS
-                                                   const float dropout_p,
-                                                   const bool has_residual
-) {
-    auto itype = dz.scalar_type();
-    auto rtype = x.scalar_type();
-    auto wtype = gamma.scalar_type();
-    auto otype = itype;
-    auto ctype = torch::kFloat32;
-    auto mtype = torch::kUInt8;
-
-    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
-
-    TORCH_CHECK(dz.dtype() == otype);
-    TORCH_CHECK(dx.dtype() == rtype);
-    TORCH_CHECK(mu.dtype() == ctype);
-    TORCH_CHECK(rsigma.dtype() == ctype);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(dz.is_cuda());
-    TORCH_CHECK(dx.is_cuda());
-    TORCH_CHECK(mu.is_cuda());
-    TORCH_CHECK(rsigma.is_cuda());
-    TORCH_CHECK(gamma.is_cuda());
-
-    TORCH_CHECK(x.is_contiguous());
-    TORCH_CHECK(dz.is_contiguous());
-    TORCH_CHECK(dx.is_contiguous());
-
-    auto sizes = x.sizes();
-    TORCH_CHECK(sizes.size() == 2);
-    TORCH_CHECK(dz.sizes() == sizes);
-    TORCH_CHECK(dx.sizes() == sizes);
-    auto rows = sizes[0];
-    auto cols = sizes[1];
-
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda())
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
+    }
     if (dmask_.has_value()) {
         auto dmask = dmask_.value();
         TORCH_CHECK(dmask.dtype() == mtype);
@@ -390,7 +284,22 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
         TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
-        TORCH_CHECK(rowscale.scalar_type() == itype);
+        TORCH_CHECK(rowscale.dtype() == itype);
     }
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda())
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+        TORCH_CHECK(x0_.has_value());
+        auto x0 = x0_.value();
+        TORCH_CHECK(x0.is_cuda())
+        TORCH_CHECK(x0.is_contiguous());
+        TORCH_CHECK(x0.sizes() == sizes);
+        TORCH_CHECK(x0.dtype() == itype);
+    }

     auto hidden_size = gamma.numel();
@@ -409,6 +318,10 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
     if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
+    at::Tensor dcolscale;
+    if (colscale_.has_value()) {
+        dcolscale = torch::empty_like(colscale_.value());
+    }

     layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
@@ -417,32 +330,40 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
     auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

-    launcher(launch_params, true, /*prenorm=*/true);
+    launcher(launch_params, true);

     auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, opts.dtype(ctype));
     auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, opts.dtype(ctype));
+    at::Tensor dcolscale_part;
+    if (colscale_.has_value()) {
+        dcolscale_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, opts.dtype(ctype));
+    }
     at::Tensor workspace, barrier;

     layer_norm::BwdParams &params = launch_params.params;
     params.rows = rows;
     params.cols = cols;
     params.x = x.data_ptr();
+    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
     params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
     params.dz = dz.data_ptr();
-    params.dx = dx.data_ptr();
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
     params.dx0 = dx0.data_ptr();
     params.dbeta = dbeta.data_ptr();
     params.dgamma = dgamma.data_ptr();
+    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
     params.dbeta_part = dbeta_part.data_ptr();
     params.dgamma_part = dgamma_part.data_ptr();
+    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
@@ -454,9 +375,14 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
         params.barrier = barrier.data_ptr<int>();
     }

-    launcher(launch_params, false, /*prenorm=*/true);
+    launcher(launch_params, false);

-    return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    if (colscale_.has_value()) {
+        result.push_back(dcolscale);
+        result.push_back(dcolscale_part);
+    }
+    return result;
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
     m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
     m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
-    m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
 }
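The merged backward now appends two extra tensors to its result when a colscale is passed (the result.push_back(dcolscale) and result.push_back(dcolscale_part) lines above). A hedged sketch of how a caller might unpack that; ext and args are placeholders, not the real module handle or argument list:

    # Assumed unpacking pattern; mirrors the C++ return shape above.
    outs = ext.dropout_add_ln_bwd(*args)
    dx0mat, dx1mat, dgamma, dbeta, dgamma_part, dbeta_part, *rest = outs
    # rest is [dcolscale, dcolscale_part] only when colscale was provided.
    dcolscale, dcolscale_part = rest if rest else (None, None)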
csrc/layer_norm/ln_bwd_kernels.cuh  (+118 −57)

@@ -7,7 +7,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_bwd_kernel(layer_norm::BwdParams params) {
@@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
+    Cvec dcolscale_sum[LDGS];
     memset(dzy_sum, 0, sizeof(dzy_sum));
     memset(dz_sum, 0, sizeof(dz_sum));
+    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }

     compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
     char * smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
@@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;

     Wvec gamma[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += Ktraits::VEC_COLS_PER_LDG;
         }
     }
@@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
+                Ivec x0;
+                if (Has_colscale) { x0.load_from(params.x0, idx); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
@@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {
-                        dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
+                        dx0_tmp_res *= params.dropout_scale;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
+                        } else {
+                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
+                        }
                     } else {
-                        dx0.data.elt[jt] = dx0_tmp_res;
+                        if (Has_colscale) {
+                            dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
+                            dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
+                        } else {
+                            dx0.data.elt[jt] = dx0_tmp_res;
+                        }
                     }
                 }
                 if (Has_residual) { dx1.store_to(params.dx1, idx); }
@@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 dz_sum[it].store_to(params.dbeta_part, idx);
                 dzy_sum[it].store_to(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
                 idx += Ktraits::VEC_COLS_PER_LDG;
             }
         }
@@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 }
             }
+            compute_t cta_dcolscale_sum[NUM_RES];
+            if (Has_colscale) {
+                __syncthreads();
+                idx = warp_m * Ktraits::VEC_COLS + tid_r;
+                #pragma unroll
+                for( int it = 0; it < LDGS; it++ ) {
+                    dcolscale_sum[it].store_to(smem_wgrad, idx);
+                    idx += THREADS_PER_ROW;
+                }
+                __syncthreads();
+                memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
+                for( int it = 0; it < ROWS_PER_CTA; it++ ) {
+                    for( int jt = 0; jt < NUM_RES; jt++ ) {
+                        cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
+                    }
+                }
+            }

             const index_t num_valid_writes
                 = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
             compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
             compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
+            compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
             for( int jt = 0; jt < NUM_RES; jt++ ) {
                 if (Is_even_cols || (jt < num_valid_writes)) {
                     *dgamma_part = cta_dzy_sum[jt];
                     dgamma_part += Ktraits::THREADS_PER_CTA;
                     *dbeta_part = cta_dz_sum[jt];
                     dbeta_part += Ktraits::THREADS_PER_CTA;
+                    if (Has_colscale) {
+                        *dcolscale_part = cta_dcolscale_sum[jt];
+                        dcolscale_part += Ktraits::THREADS_PER_CTA;
+                    }
                 }
             }
         }
     }
 }

-template<typename Kernel_traits, bool Is_even_cols>
+template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
 void ln_bwd_finalize_kernel(BwdParams params)
 {
@@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
     constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
     for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
         // Each thread sums over NUM_ELT columns.
-        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
+        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
         memset(&dgamma_local, 0, sizeof(dgamma_local));
         memset(&dbeta_local, 0, sizeof(dbeta_local));
+        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
         if (Is_even_cols || col < params.cols) {
             for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
                 // index_t idx = row * Kernel_traits::COLS + col;
                 index_t idx = row * params.cols + col;
-                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
+                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
                 dbeta_part.load_from(params.dbeta_part, idx);
                 dgamma_part.load_from(params.dgamma_part, idx);
+                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
                 #pragma unroll
                 for( int it = 0; it < NUM_ELT; it++ ) {
                     dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
                     dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
+                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
                 }
             }
         }
         void * smem_gamma = smem_;
         void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
         const int write_row = warp;
         const int write_col = lane ^ write_row;
@@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)
         dgamma_local.store_to(smem_gamma, write_idx);
         dbeta_local.store_to(smem_beta, write_idx);
+        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
         __syncthreads();

         // It would be probably safe to reuse the first row of smem_beta and smem_gamma
-        void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
-        void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
+        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
+        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];

         // More than one iter iff ROWS_PER_CTA < 32.
@@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)
             memset(&dbeta_local, 0, sizeof(dbeta_local));
             memset(&dgamma_local, 0, sizeof(dgamma_local));
+            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }

             // Load beta and gamma transposed
             if( read_row < Kernel_traits::ROWS_PER_CTA ){
                 dbeta_local.load_from(smem_beta, read_idx);
                 dgamma_local.load_from(smem_gamma, read_idx);
+                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
             }

             // Call reducer on the loaded value(s) and convert.
@@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)
                 dgamma_local.data.elt[it] = g_i;
                 dbeta_local.data.elt[it] = b_i;
+                if (Has_colscale) {
+                    compute_t cs_i = dcolscale_local.data.elt[it];
+                    cs_i = reducer.allreduce(cs_i, sum);
+                    dcolscale_local.data.elt[it] = cs_i;
+                }
             }

             // Leader stores the result at the current column.
             if( lane == 0 ){
                 dgamma_local.store_to(smem_gamma_out, w);
                 dbeta_local.store_to(smem_beta_out, w);
+                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
             }
         }
@@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
             using src_t = typename TypeToVec2<compute_t>::Type;
             using dst_t = typename TypeToVec2<weight_t>::Type;
-            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
-            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
+            Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
+            Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;

             dgamma_vec2.load_from(smem_gamma_out, lane);
             dbeta_vec2.load_from(smem_beta_out, lane);
+            if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
             #pragma unroll
             for( int it = 0; it < NUM_ELT; it++ ) {
                 dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
                 dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
+                if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
             }

             dgamma_out2.store_to(params.dgamma, col_out);
             dbeta_out2.store_to(params.dbeta, col_out);
+            if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
         }
     }
 }
@@ -364,7 +420,7 @@ template<
          int BYTES_PER_LDG_MAIN,
          int BYTES_PER_LDG_FINAL
>
-void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
+void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
     using Kernel_traits = Kernel_traits<weight_t,
                                         input_t,
@@ -378,59 +434,64 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                         WARPS_N,
                                         BYTES_PER_LDG_MAIN
                                         >;
+    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
     bool has_residual = launch_params.params.dx1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(prenorm, PrenormConst, [&] {
         BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
             BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+                    BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
                     ...
                     (configure_params / occupancy / workspace-and-barrier sizing / main-kernel launch logic kept, re-indented under the extra switch)
                     ...
                         using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                                    weight_t,
                                                                                    input_t,
                                                                                    residual_t,
                                                                                    output_t,
                                                                                    compute_t,
                                                                                    index_t,
+                                                                                   HasColscaleConst,
                                                                                    32 * 32,  // THREADS_PER_CTA
                                                                                    BYTES_PER_LDG_FINAL>;

-                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
-                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                        auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
+                        kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
+                    });
                 });
             });
         });
     });
 ...
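A hedged PyTorch reference for the per-column quantity the backward kernel accumulates into dcolscale_sum and writes out through dcolscale_part. Variable names are illustrative, and dx below stands for the gradient with respect to the pre-LayerNorm sum x, not for any tensor in the kernel's API:

    # fwd (per the kernels above):  x = dropout(x0 * rowscale) * colscale + x1
    # bwd:  dcolscale[j] = sum_i dx[i, j] * dropout(x0 * rowscale)[i, j]
    import torch

    def dcolscale_reference(dx, x0, rowscale, dmask, dropout_p):
        x0s = x0 * rowscale[:, None] if rowscale is not None else x0
        x0d = x0s * dmask / (1.0 - dropout_p)   # dropped-and-rescaled input
        return (dx.float() * x0d.float()).sum(dim=0)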
csrc/layer_norm/ln_fwd_kernels.cuh  (+45 −44)

@@ -16,7 +16,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_fwd_kernel(FwdParams params) {
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;

-    constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
+    const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);

     extern __shared__ char smem_[];
@@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) {
     Wvec gamma[LDGS];
     Wvec beta[LDGS];
+    Wvec colscale[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
             beta[it].load_from(params.beta, idx);
+            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }
     }
@@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) {
                     // the more efficient curand_uniform4.
                     mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                     compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
-                    compute_t x_ij;
-                    if (Has_residual) {
-                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
-                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
-                    } else {
-                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
-                    }
+                    x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+                    if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
+                    compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;
                     if (Is_dropout) { dmask.data.elt[jt] = keep; }
@@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) {
     const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
     const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
     const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
-    // Need to convert to int, otherwise the subtraction will wrap around.
     auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+        // Need to convert to int, otherwise the subtraction will wrap around.
         const index_t valid_partial_vecs_in_warp =
             std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                      int(THREADS_PER_WARP));
@@ -206,45 +204,48 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                                         BYTES_PER_LDG
                                         >;
     bool has_residual = launch_params.params.x1 != nullptr;
+    bool has_colscale = launch_params.params.colscale != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
         BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+            BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
                 ...
                 (configure_params / occupancy / elts_per_thread / workspace sizing and the kernel launch kept, re-indented under the extra switch)
                 ...
+                });
             });
         });
     });
 ...
csrc/layer_norm/ln_kernel_traits.h  (+3 −1)

@@ -38,6 +38,7 @@ template<
          typename output_t_,
          typename compute_t_,
          typename index_t_,
+         bool Has_colscale,
          uint32_t THREADS_PER_CTA_,
          uint32_t BYTES_PER_LDG_,
          typename Base = Kernel_traits_base<HIDDEN_SIZE_,
@@ -69,7 +70,8 @@ struct Kernel_traits_finalize : public Base {
     // Shared memory size to coalsece the CTA result.
     enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
     // Shared memory requirement per CTA.
-    enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
+    static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
+    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };

     // The type of the reducer.
     using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
 ...
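A quick sanity check of the new shared-memory sizing in the finalize kernel traits: with colscale fused, a third reduction factor (dcolscale) joins dgamma and dbeta. The byte counts below are illustrative placeholders, not values taken from the kernel traits:

    # Assumed, illustrative sizes for one transpose tile and one output tile.
    SMEM_BYTES_TRANSPOSE = 4096
    SMEM_BYTES_OUTPUT = 1024
    for has_colscale in (False, True):
        num_factors = 3 if has_colscale else 2   # dgamma, dbeta (+ dcolscale)
        smem_per_cta = num_factors * (SMEM_BYTES_TRANSPOSE + SMEM_BYTES_OUTPUT)
        print(has_colscale, smem_per_cta)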
csrc/layer_norm/ln_utils.cuh  (+2 −2)

@@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
 #define REGISTER_BWD_LAUNCHER(                                                                                              \
     HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
     void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,    \
-                                                                                const bool configure_params, const bool prenorm) { \
+                                                                                const bool configure_params) {             \
         launch_<WTYPE,                                                                                                     \
                 ITYPE,                                                                                                     \
                 RTYPE,                                                                                                     \
@@ -57,7 +57,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
                 WARPS_M,                                                                                                   \
                 WARPS_N,                                                                                                   \
                 BYTES_PER_LDG,                                                                                             \
-                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm);                                         \
+                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                  \
     }                                                                                                                       \
     static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
         ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
 ...
flash_attn/ops/layer_norm.py
View file @
ae137ed1
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
                                    residual_in_fp32):
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                    residual_in_fp32):
    """ Assume that arguments are contiguous
    """
...
@@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
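For reference, here is a minimal unfused sketch of the math the fused dropout_add_ln_fwd implements with the new colscale (LayerScale) argument: x0 is scaled per-feature by colscale (and per-row by rowscale if given) before dropout and the residual add, then LayerNorm is applied. The function name and the dtype handling below are simplifications for illustration, not the extension's actual behavior.

import torch
import torch.nn.functional as F

def dropout_add_layer_norm_reference(x0, x1, weight, bias, dropout_p, epsilon,
                                     rowscale=None, colscale=None):
    """Unfused sketch: scale -> dropout -> residual add -> LayerNorm."""
    if colscale is not None:
        x0 = x0 * colscale                # per-feature (LayerScale) scaling
    if rowscale is not None:
        x0 = x0 * rowscale.unsqueeze(-1)  # per-row scaling
    x = F.dropout(x0, p=dropout_p)
    if x1 is not None:
        x = x + x1                        # residual add
    out = F.layer_norm(x, (x.shape[-1],), weight, bias, eps=epsilon)
    return out, x                         # x is the pre-norm residual stream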
def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
                                     has_residual):
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + x1 was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    # dmask is None if dropout_p == 0.0
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    colscale = colscale.view(-1) if colscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
        has_residual
    )
    # dx1mat is None if not has_residual
    return dx0mat, dx1mat, dgamma, dbeta
    if colscale is None:
        return dx0mat, dx1mat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
                                             dropout_p, has_residual):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape)
    rowscale = rowscale.view(-1) if rowscale is not None else None
    dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
        dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
    )
    return dx0mat, dx1mat, dgamma, dbeta
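The assert in _dropout_add_layer_norm_backward ('x0 is required to compute the gradient of colscale') follows from the chain rule: colscale multiplies x0 elementwise with broadcasting over rows, so its gradient is a reduction over rows that needs x0 itself. A tiny self-contained check (toy shapes, illustrative names):

import torch

rows, hidden = 8, 4
x0 = torch.randn(rows, hidden, dtype=torch.float64)
colscale = torch.randn(hidden, dtype=torch.float64, requires_grad=True)

(x0 * colscale).sum().backward()

# For y = x0 * colscale (broadcast over rows) and a sum loss,
# dL/dcolscale_j = sum_i x0[i, j]: the row-reduced x0.
assert torch.allclose(colscale.grad, x0.sum(dim=0))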
class DropoutAddLayerNormFN(torch.autograd.Function):
class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
                prenorm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape)
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), dmask
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
            dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
        dcolscale = rest[0] if colscale is not None else None
        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None


class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
                return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return zmat.view(x0.shape), xmat.view(x0.shape)
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return zmat.view(x0.shape), xmat.view(x0.shape), dmask

    @staticmethod
    def backward(ctx, dz, dx, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = dx.contiguous()  # this happens!
        x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
            dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        return dx0, dx1, dgamma, dbeta, None, None, None, None, None
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                           prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
            return_dropout_mask)
    if not prenorm:
        return DropoutAddLayerNormFN.apply(*args)
    else:
        return DropoutAddLayerNormPrenormFN.apply(*args)
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        return_dropout_mask
    )


class DropoutAddLayerNorm(torch.nn.Module):
...
...
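A hedged usage sketch of the updated user-facing API (shapes, the 1e-4 init value, and dtypes are illustrative, not prescribed by this commit): LayerScale is passed via the new layerscale keyword as a learnable per-feature vector, and prenorm=True now goes through the same merged autograd Function, returning the residual stream alongside the normalized output.

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden = 1024
x0 = torch.randn(2, 128, hidden, device='cuda', dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)
# LayerScale is typically initialized to a small constant (e.g. 1e-4 in CaiT).
layerscale = torch.full((hidden,), 1e-4, device='cuda', dtype=torch.float16,
                        requires_grad=True)

out, new_residual = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5,
    layerscale=layerscale, prenorm=True, residual_in_fp32=False)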
tests/ops/test_dropout_layer_norm.py
View file @
ae137ed1
...
@@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
...
@@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale):
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
...
@@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
...
@@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
...
@@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
    model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                        model.epsilon, rowscale=rowscale,
                                        model.epsilon, rowscale=rowscale, layerscale=colscale,
                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
...
@@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
    assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
...
@@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
...
@@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('has_residual', [False])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale):
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
...
@@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
...
@@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
...
@@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
    model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                                  model.epsilon, rowscale=rowscale, prenorm=True,
                                                  model.epsilon, rowscale=rowscale, layerscale=colscale, prenorm=True,
                                                  residual_in_fp32=residual_in_fp32,
                                                  return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
...
@@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
    assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
...