gaoqiong / composable_kernel · Commits

Commit e568bfdb, authored Aug 31, 2023 by Qianfeng Zhang
Parent: bcbeed99

    Add examples to test dropout device-op and kernel

Showing 5 changed files with 617 additions and 27 deletions:

    example/52_flash_atten_bias/CMakeLists.txt                                           +4   -1
    example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp   +333 -0
    example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_zcheck.inc  +251 -0
    include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp               +17  -26
    include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp                    +12  -0
example/52_flash_atten_bias/CMakeLists.txt  (view file @ e568bfdb)

@@ -3,3 +3,6 @@ add_example_executable(example_grouped_multihead_attention_bias_forward_v2 group...
 add_example_executable(example_batched_multihead_attention_bias_backward_v2 batched_multihead_attention_bias_backward_v2.cpp)
 add_example_executable(example_grouped_multihead_attention_bias_backward_v2 grouped_multihead_attention_bias_backward_v2.cpp)
+add_example_executable(example_batched_multihead_attention_bias_forward_v2_zcheck batched_multihead_attention_bias_forward_v2_zcheck.cpp)
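With the target registered, the new example should build like the others in this directory (target name taken from the first argument above; the exact invocation depends on how the CK tree was configured):

    make example_batched_multihead_attention_bias_forward_v2_zcheck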
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp  (new file, 0 → 100644, view file @ e568bfdb)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
                                                                   |-----------------|
                                                                          Gemm0
                                                           |-------------------------------------|
                                                                           Gemm1
*/
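// For intuition, the fused op corresponds to this naive per-batch host loop
// (an illustrative sketch only, not part of this example's verification path):
//
//   for(int g = 0; g < G; ++g)
//       for(int m = 0; m < M; ++m)
//       {
//           // Gemm0:   s[n] = alpha * sum_k A[g][m][k] * B0[g][k][n]  (+ bias, dropout)
//           // Softmax: p[n] = exp(s[n] - max(s)) / sum_n exp(s[n] - max(s))
//           // Gemm1:   C[g][m][o] = sum_n p[n] * B1[g][n][o]
//       }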
#define DIM 128 // DIM should be a multiple of 8.
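// A compile-time guard for the constraint above could read (hypothetical, not
// part of this commit): static_assert(DIM % 8 == 0, "DIM should be a multiple of 8");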
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/host_common_util.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16   = ck::half_t;
using BF16  = ck::bhalf_t;
using F32   = float;
using U16   = unsigned short;
using INT32 = int32_t;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DataType         = F16;
using GemmDataType     = F16;
using ADataType        = DataType;
using B0DataType       = DataType;
using B1DataType       = DataType;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = DataType;
using ZDataType        = U16; // INT32
using LSEDataType      = F32;
using Acc0BiasDataType = F16;
using Acc1BiasDataType = void;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType, GemmDataType, ZDataType, LSEDataType,
        Acc0BiasDataType, Acc1BiasDataType, AccDataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1,
        256,
        128,            // MPerBlock
        128,            // NPerBlock
        32,             // KPerBlock
        32,             // Gemm1NPerBlock
        32,             // Gemm1KPerBlock
        8,              // AK1
        8,              // BK1
        2,              // B1K1
        32,             // MPerXDL
        32,             // NPerXDL
        1,              // MXdlPerWave
        4,              // NXdlPerWave
        1,              // Gemm1NXdlPerWave
        1,              // DropoutStep
        S<4, 64, 1>,    // ABlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 64, 1>,    // BBlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        4,
        S<16, 16, 1>,   // B1BlockTransfer
        S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false,
        1,              // CShuffleMXdlPerWavePerShuffle
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        4,
        MaskingSpec,    // MaskingSpecialization
        Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType, GemmDataType, ZDataType, LSEDataType,
        Acc0BiasDataType, Acc1BiasDataType, AccDataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1,
        256,
        128,            // MPerBlock
        128,            // NPerBlock
        32,             // KPerBlock
        64,             // Gemm1NPerBlock
        32,             // Gemm1KPerBlock
        8,              // AK1
        8,              // BK1
        2,              // B1K1
        32,             // MPerXDL
        32,             // NPerXDL
        1,              // MXdlPerWave
        4,              // NXdlPerWave
        2,              // Gemm1NXdlPerWave
        1,              // DropoutStep
        S<4, 64, 1>,    // ABlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 64, 1>,    // BBlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        4,
        S<16, 16, 1>,   // B1BlockTransfer
        S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        4,
        MaskingSpec,    // MaskingSpecialization
        Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType, GemmDataType, ZDataType, LSEDataType,
        Acc0BiasDataType, Acc1BiasDataType, AccDataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1,
        256,
        128,            // MPerBlock
        128,            // NPerBlock
        32,             // KPerBlock
        128,            // Gemm1NPerBlock
        32,             // Gemm1KPerBlock
        8,              // AK1
        8,              // BK1
        2,              // B1K1
        32,             // MPerXDL
        32,             // NPerXDL
        1,              // MXdlPerWave
        4,              // NXdlPerWave
        4,              // Gemm1NXdlPerWave
        1,              // DropoutStep
        S<4, 64, 1>,    // ABlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 64, 1>,    // BBlockTransfer
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        4,
        S<8, 32, 1>,    // B1BlockTransfer
        S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        4,
        MaskingSpec,    // MaskingSpecialization
        Deterministic>;
#endif
using DeviceDropoutInstance = ck::tensor_operation::device::DeviceBatchedDropout<
    NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
    GemmDataType, ZDataType, GemmDataType,
    GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
    256,  // BlockSize
    128,  // MPerBlock
    128,  // NPerBlock
    32,   // KPerBlock
    128,  // Gemm1NPerBlock
    8,    // AK1
    8,    // BK1
    2,    // B1K1
    32,   // MPerXDL
    32,   // NPerXDL
    1,    // MXdlPerWave
    4>;   // NXdlPerWave
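// How the z-check below works: the attention kernel writes out the dropout
// random matrix Z it generates during the forward pass, and DeviceBatchedDropout
// regenerates Z standalone from the same {seed, offset}; run() in the included
// file compares the two buffers element-wise.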
#include "run_batched_multihead_attention_bias_forward_zcheck.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_zcheck.inc  (new file, 0 → 100644, view file @ e568bfdb)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
    using ck::host_common::dumpBufferToFile;

    int init_method = 1;

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 200; // 120
    ck::index_t N = 200; // 1000
    ck::index_t K = DIM;
    ck::index_t O = DIM;

    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
    ck::index_t G0 = 8;
    ck::index_t G1 = 4;

    bool input_permute  = false;
    bool output_permute = true;

    float p_drop = 0.1;

    const unsigned long long seed   = 1;
    const unsigned long long offset = 0;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 2)
    {
        init_method = std::stoi(argv[1]);
    }
    else if(argc == 11)
    {
        init_method    = std::stoi(argv[1]);
        M              = std::stoi(argv[2]);
        N              = std::stoi(argv[3]);
        K              = std::stoi(argv[4]);
        O              = std::stoi(argv[5]);
        G0             = std::stoi(argv[6]);
        G1             = std::stoi(argv[7]);
        p_drop         = std::stof(argv[8]);
        input_permute  = std::stoi(argv[9]);
        output_permute = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg2 to 7: M, N, K, O, G0, G1\n");
        printf("arg8: drop_prob\n");
        printf("arg9 to 10: input / output permute\n");
        exit(0);
    }
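    // For reference, a typical full-argument invocation (executable name from
    // CMakeLists.txt above; the values shown are this file's defaults):
    //   ./example_batched_multihead_attention_bias_forward_v2_zcheck 1 200 200 128 128 8 4 0.1 0 1
    //   (init_method M N K O G0 G1 drop_prob input_permute output_permute)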
    float alpha = 1.f / std::sqrt(K);
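    // alpha is the usual attention softmax scaling 1/sqrt(d_k); with the default
    // K = DIM = 128 this gives alpha = 1/sqrt(128) ≈ 0.0884.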
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
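    // Worked example with the defaults (G1 = 4, M = 200, K = 128): the permuted
    // [G0, M, G1, K] strides are {M*G1*K, K, G1*K, 1} = {102400, 128, 512, 1},
    // while the packed [G0, G1, M, K] strides are {G1*M*K, M*K, K, 1} =
    // {102400, 25600, 128, 1}; the same pattern applies to the tensors below.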
    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides =
        input_permute
            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides =
        output_permute
            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d_gs_ms_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]

    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> z_gs_ms_ns_strides =
        input_permute
            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]

    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
    std::vector<ck::index_t> lse_gs_ms_strides =
        std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
    Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
    Tensor<ZDataType> z_gs_ms_ns_2(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
    std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
    std::cout << "lse_gs_ms_os: " << lse_gs_ms_host_result.mDesc << std::endl;
    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});
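    // Note: only the host-side copy of Z is zero-filled here (it is never copied
    // to the device); both device Z buffers are produced entirely by the kernels
    // from {seed, offset} and then read back for comparison.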
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
    }
    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
    DeviceMem z_device_buf_2(sizeof(ZDataType) * z_gs_ms_ns_2.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    auto gemm_op      = DeviceGemmInstance{};
    auto gemm_invoker = gemm_op.MakeInvoker();

    // run for storing z tensor
    auto gemm_arg =
        gemm_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                             static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
                             static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
                             static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                             static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
                             static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
                             static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
                             nullptr,
                             a_gs_ms_ks_lengths,
                             a_gs_ms_ks_strides,
                             b0_gs_ns_ks_lengths,
                             b0_gs_ns_ks_strides,
                             b1_gs_os_ns_lengths,
                             b1_gs_os_ns_strides,
                             c_gs_ms_os_lengths,
                             c_gs_ms_os_strides,
                             z_gs_ms_ns_lengths,
                             z_gs_ms_ns_strides,
                             lse_gs_ms_lengths,
                             d_gs_ms_ns_lengths,
                             d_gs_ms_ns_strides,
                             {},
                             {},
                             a_element_op,
                             b0_element_op,
                             acc0_element_op,
                             b1_element_op,
                             c_element_op,
                             p_drop,          // dropout ratio
                             {seed, offset}); // dropout random seed and offset, offset should
                                              // be at least the number of elements on a thread
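    // Rough estimate of that per-thread count (illustrative, not from this commit):
    // with the 128x128 Z tile and 256-thread blocks configured above, each thread
    // produces 128 * 128 / 256 = 64 Z values per tile.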
    if(!gemm_op.IsSupportedArgument(gemm_arg))
    {
        std::cout << gemm_op.GetTypeString() << " does not support this problem" << std::endl;
        return 0;
    }

    c_device_buf.SetZero();
    lse_device_buf.SetZero();

    gemm_invoker.Run(gemm_arg, StreamConfig{nullptr, false});

    z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
    dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size());
    // do Dropout
    auto dropout_op      = DeviceDropoutInstance();
    auto dropout_invoker = dropout_op.MakeInvoker();

    auto dropout_arg =
        dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
                                a_gs_ms_ks_lengths,
                                a_gs_ms_ks_strides,
                                b0_gs_ns_ks_lengths,
                                b0_gs_ns_ks_strides,
                                z_gs_ms_ns_lengths,
                                z_gs_ms_ns_strides,
                                {seed, offset});

    dropout_invoker.Run(dropout_arg, StreamConfig{nullptr, false});

    z_device_buf_2.FromDevice(z_gs_ms_ns_2.mData.data());
    dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size());

    return ck::utils::check_integer_err(z_gs_ms_ns.mData, z_gs_ms_ns_2.mData, 1.0e-5);
}
include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp  (view file @ e568bfdb)
@@ -25,7 +25,8 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename ZDataType,
+template <typename GridwiseDropout_,
+          typename ZDataType,
           typename AGridDesc_AK0_M_AK1,
           typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
           typename Block2CTileMap,
@@ -60,7 +61,7 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
 
-    GridwiseDropout::Run(z_matrix_ptr,
+    GridwiseDropout_::Run(z_matrix_ptr,
                          a_grid_desc_ak0_m_ak1,
                          c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                          block_2_ctile_map,
@@ -109,7 +110,7 @@ template <index_t NumDimG,
           index_t NPerXDL,
           index_t MXdlPerWave,
           index_t NXdlPerWave>
-struct DeviceBatchedDropout : public BaseOperator
+struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -220,7 +221,7 @@ struct DeviceBatchedDropout : public BaseOperator
               raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                             b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                                             b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]},
-              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
+              batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
         {
             compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(z_grid_desc_g_m_n_);
@@ -228,7 +229,7 @@ struct DeviceBatchedDropout : public BaseOperator
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
 
-            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                 GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
                     z_grid_desc_m_n_);
             // Print();
@@ -237,14 +238,6 @@ struct DeviceBatchedDropout : public BaseOperator
             n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
         }
 
-        void Print() const
-        {
-            std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
-                      << a_grid_desc_g_m_k_.GetLength(I1) << ", "
-                      << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
-            // a_grid_desc_g_m_k_.Print();
-        }
-
         // pointers
         ZDataType* p_z_grid_;
@@ -257,7 +250,7 @@ struct DeviceBatchedDropout : public BaseOperator
         ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
-            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
         // block-to-c-tile map
         typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -294,6 +287,7 @@ struct DeviceBatchedDropout : public BaseOperator
         auto launch_kernel = [&]() {
             const auto kernel = kernel_batched_dropout<
+                GridwiseDropout,
                 ZDataType,
                 DeviceOp::AGridDesc_AK0_M_AK1,
                 typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
@@ -307,7 +301,7 @@ struct DeviceBatchedDropout : public BaseOperator
                 0,
                 arg.p_z_grid_,
                 arg.a_grid_desc_ak0_m_ak1_,
-                arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
+                arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
                 arg.compute_base_ptr_of_batch_,
@@ -336,9 +330,7 @@ struct DeviceBatchedDropout : public BaseOperator
     static bool IsSupportedArgument(const Argument& arg)
     {
-#if DEBUG_LOG
-        (void)arg;
-        arg.Print();
-#endif
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
              ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
@@ -425,8 +417,7 @@ struct DeviceBatchedDropout : public BaseOperator
             << "ASpec" << getTensorSpecializationString(ASpec) << ", "
             << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
             << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
-            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
-            << getMaskingSpecializationString(MaskingSpec) << ">";
+            << "CSpec" << getTensorSpecializationString(CSpec);
         // clang-format on
         return str.str();
include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp  (view file @ e568bfdb)
@@ -238,6 +238,14 @@ struct GridwiseBatchedDropout
         constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
         constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
 
+        // only used for BlockwiseDropout
+        constexpr auto thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
+        // only used for providing ApplyDropoutAttnBwdSaveZ
+        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
+            static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
+
         //
         // z vgpr copy to global
         //
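(In the added BlockwiseDropout above, the keep probability is encoded as a uint16 threshold: 0.8f * 65535.f truncates to 52428, and the companion rescale factor is 1.0f / 0.8f = 1.25. The 0.8 appears to be a fixed value for this z-check path rather than the user-supplied p_drop, which does not reach this kernel; the comparison concerns only the generated random matrix Z.)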
@@ -332,6 +340,9 @@ struct GridwiseBatchedDropout
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
 
+        auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+            Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
+
         // gemm0 M loop
         index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
@@ -372,5 +383,6 @@ struct GridwiseBatchedDropout
                 make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
         } while(0 < gemm0_m_block_outer_index--); // end j loop
+    };
 };
 
 } // namespace ck