Commit 19dd98c8 authored by guangzlu's avatar guangzlu
Browse files

added z tensor datatype choice for fwd pass

parent 79f3caf8
......@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......
......@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......
......@@ -21,6 +21,7 @@
namespace ck {
template <typename FloatAB,
typename ZDataType,
typename FloatGemm,
typename FloatGemmAcc,
typename FloatCShuffle,
......@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid,
ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment