Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dc8e0148
Commit
dc8e0148
authored
Jun 06, 2023
by
guangzlu
Browse files
added dropout shuffle for attn fwd, bwd v4 can pass now
parent
78c1482a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
345 additions
and
160 deletions
+345
-160
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v5.cpp
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+60
-65
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+79
-63
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+203
-29
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
View file @
dc8e0148
...
@@ -78,7 +78,7 @@ using GemmDataType = F16;
...
@@ -78,7 +78,7 @@ using GemmDataType = F16;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
INT32
;
// INT32
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
...
@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
...
@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
tru
e
;
static
constexpr
bool
Deterministic
=
fals
e
;
// DIM should be a multiple of 8.
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If DIM <= 32 , ues prototype1 1st template.
...
@@ -716,7 +716,7 @@ int run(int argc, char* argv[])
...
@@ -716,7 +716,7 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G1
=
2
;
// 16
ck
::
index_t
G1
=
1
;
// 16
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
dc8e0148
...
@@ -123,8 +123,9 @@ struct BlockwiseDropout
...
@@ -123,8 +123,9 @@ struct BlockwiseDropout
}
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
__host__
__device__
void
ApplyDropoutAttnFwd
(
CThreadBuffer
&
in_thread_buf
,
ApplyDropout_v1r1
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
//
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
//
{
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
...
@@ -161,7 +162,7 @@ struct BlockwiseDropout
...
@@ -161,7 +162,7 @@ struct BlockwiseDropout
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
element_global_1d_id
,
index_t
MRaw
)
//
index_t
MRaw
)
{
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
...
@@ -195,10 +196,11 @@ struct BlockwiseDropout
...
@@ -195,10 +196,11 @@ struct BlockwiseDropout
}
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
_v1r2
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropout
AttnBwdSaveZ
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
...
@@ -215,17 +217,17 @@ struct BlockwiseDropout
...
@@ -215,17 +217,17 @@ struct BlockwiseDropout
ushort
tmp
[
tmp_size
];
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
}
ushort
tmp_id
[
tmp_size
];
//
ushort tmp_id[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
for(int j = 0; j < 4; j++)
{
//
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
;
//
tmp_id[i * 4 + j] = element_global_1d_id + i * 8
* MRaw
;
}
//
}
}
//
}
block_sync_lds
();
block_sync_lds
();
...
@@ -235,18 +237,15 @@ struct BlockwiseDropout
...
@@ -235,18 +237,15 @@ struct BlockwiseDropout
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
_id
[
tmp_index
];
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
}
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwdSaveZ
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropoutWithZ
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ZThreadBuffer
&
z_thread_buf
)
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
...
@@ -256,6 +255,26 @@ struct BlockwiseDropout
...
@@ -256,6 +255,26 @@ struct BlockwiseDropout
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
};
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
});
});
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
>
__host__
__device__
void
GenerateZMatrixAttnFwd
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
4
;
...
@@ -263,17 +282,17 @@ struct BlockwiseDropout
...
@@ -263,17 +282,17 @@ struct BlockwiseDropout
ushort
tmp
[
tmp_size
];
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
}
ushort
tmp_id
[
tmp_size
];
//
ushort tmp_id[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
for(int j = 0; j < 4; j++)
{
//
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
//
tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
}
//
}
}
//
}
block_sync_lds
();
block_sync_lds
();
...
@@ -281,36 +300,12 @@ struct BlockwiseDropout
...
@@ -281,36 +300,12 @@ struct BlockwiseDropout
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
}
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v2
(
CThreadBuffer
&
in_thread_buf
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
});
});
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
>
template
<
typename
ZThreadBuffer
>
__host__
__device__
void
GenerateZMatrix
(
ck
::
philox
&
ph
,
__host__
__device__
void
GenerateZMatrix
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
element_global_1d_id
,
...
@@ -332,14 +327,14 @@ struct BlockwiseDropout
...
@@ -332,14 +327,14 @@ struct BlockwiseDropout
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
}
ushort
tmp_id
[
tmp_size
];
//
ushort tmp_id[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
for(int j = 0; j < 4; j++)
{
//
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
//
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
}
//
}
}
//
}
block_sync_lds
();
block_sync_lds
();
...
@@ -347,7 +342,7 @@ struct BlockwiseDropout
...
@@ -347,7 +342,7 @@ struct BlockwiseDropout
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
z_thread_buf
(
offset
)
=
tmp
_id
[
tmp_index
];
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
dc8e0148
...
@@ -39,6 +39,7 @@ template <typename GridwiseGemm,
...
@@ -39,6 +39,7 @@ template <typename GridwiseGemm,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
...
@@ -70,6 +71,8 @@ __global__ void
...
@@ -70,6 +71,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
...
@@ -127,6 +130,7 @@ __global__ void
...
@@ -127,6 +130,7 @@ __global__ void
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
,
block_2_ctile_map
,
c0_matrix_mask
,
c0_matrix_mask
,
...
@@ -159,6 +163,7 @@ __global__ void
...
@@ -159,6 +163,7 @@ __global__ void
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
,
block_2_ctile_map
,
c0_matrix_mask
,
c0_matrix_mask
,
...
@@ -648,6 +653,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -648,6 +653,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
=
GridwiseGemm
::
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
z_grid_desc_m_n_
);
if
(
p_lse_grid
==
nullptr
)
if
(
p_lse_grid
==
nullptr
)
{
{
is_lse_storing_
=
false
;
is_lse_storing_
=
false
;
...
@@ -693,9 +702,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -693,9 +702,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -752,8 +765,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -752,8 +765,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
auto
is_dropout_
,
auto
is_lse_storing_
)
{
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle
<
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
...
@@ -771,6 +785,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -771,6 +785,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
LSEGridDesc_M
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
...
@@ -802,6 +817,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -802,6 +817,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
dc8e0148
...
@@ -122,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -122,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
// C desc for source in
block
wise copy
// C desc for source in
grid
wise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
////=> for z use
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
////=> for z use
{
{
...
@@ -143,6 +143,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -143,6 +143,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
}
// C shuffle desc for source in gridwise copy
__host__
__device__
static
constexpr
auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
////=> for z use to shuffle
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
// printf("M / MPerBlock %d, ", M / MPerBlock);
// printf("MXdlPerWave %d, " , MXdlPerWave);
// printf("Gemm0MWaves %d, " , Gemm0MWaves);
// printf("MPerXdl / N5 %d, " , MPerXdl / N5);
// printf("N5 %d \n" , N5);
// printf("N / NPerBlock %d, " , N / NPerBlock);
// printf("NXdlPerWave %d, " , NXdlPerWave);
// printf("Gemm0NWaves %d, " , Gemm0NWaves);
// printf("N3 %d, " , N3);
// printf("N4 %d, " , N4);
// printf("N5 %d, " , N5);
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
/
N5
,
N5
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
9
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
10
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
__device__
static
auto
GetGemm0WaveIdx
()
{
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
const
index_t
thread_id
=
get_thread_local_1d_id
();
...
@@ -381,6 +416,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -381,6 +416,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
using
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
ZGridDesc_M_N
{}))
>
;
struct
SharedMemTrait
struct
SharedMemTrait
{
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
...
@@ -441,6 +479,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -441,6 +479,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
&
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
@@ -856,6 +896,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -856,6 +896,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// z vgpr copy to global
// z vgpr copy to global
//
//
// z matrix threadwise desc
// z matrix threadwise desc
// if(get_thread_global_1d_id()==0){
// printf("m2 is %d \n",m2.value);
// printf("n2 is %d \n",n2.value);
// printf("n3 is %d \n",n3.value);
// printf("n4 is %d \n",n4.value);
//}
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
...
@@ -868,20 +916,138 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -868,20 +916,138 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3
,
// NInputNum
n3
,
// NInputNum
n4
));
// registerNum
n4
));
// registerNum
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
//
I1
,
//
m0
,
//
n0
,
//
m1
,
//
n1
,
//
m2
,
// m0
n2
,
// m1
n3
,
// n0
n4
,
// n1
I1
));
// n2
// ignore = z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
;
z_tenor_tmp_buffer
;
z_tenor_tmp_buffer
.
Clear
();
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
// z buffer after shuffle
z_tenor_buffer
.
Clear
();
z_tenor_buffer
.
Clear
();
// z matrix global desc
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
auto
z_grid_tmp_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
// ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// printf("z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize() is
// %ld \n", z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
//
//}
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
// if(get_block_1d_id()==0){
// if(get_thread_local_1d_id()==0){
// printf("tid = 0 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==1){
// printf("tid = 1 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==2){
// printf("tid = 2 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==3){
// printf("tid = 3 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==32){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==64){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto
z_tmp_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_shuffle_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ushort
,
decltype
(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
),
Sequence
<
I1
,
I1
,
m0
,
n0
,
m1
,
n1
,
m2
,
n2
,
n3
,
n4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
>
,
10
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
int
(
wave_m_n_id
[
I1
]
/
4
),
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
,
wave_m_n_id
[
I1
]
%
4
)};
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ZDataType
,
ZDataType
,
...
@@ -1060,10 +1226,29 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1060,10 +1226,29 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
blockwise_dropout
.
template
ApplyDropout_v1r2
<
decltype
(
acc_thread_buf
),
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tenor_tmp_buffer
)>(
ph
,
global_elem_id
,
z_tenor_tmp_buffer
);
z_tmp_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_tmp_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_tmp_buf
);
block_sync_lds
();
z_shuffle_thread_copy_global_to_vgpr
.
Run
(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
z_grid_tmp_buf
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
false
>(
false
>(
acc_thread_buf
,
acc_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
@@ -1071,28 +1256,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1071,28 +1256,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_grid_buf
);
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
block_sync_lds
();
// decltype(z_tenor_buffer),
// false,
z_tmp_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
// decltype(n0),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
// decltype(i)>(
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
z_shuffle_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
//
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
// z_thread_copy_vgpr_to_global.Run(
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
@@ -1101,7 +1275,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1101,7 +1275,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
{
// ignore = z_grid_buf;
// ignore = z_grid_buf;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropout
_v1r1
<
decltype
(
acc_thread_buf
),
false
>(
blockwise_dropout
.
template
ApplyDropout
AttnFwd
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
,
global_elem_id
);
acc_thread_buf
,
ph
,
global_elem_id
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment