Commit 19dd98c8 authored by guangzlu's avatar guangzlu
Browse files

added z tensor datatype choice for fwd pass

parent 79f3caf8
...@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename FloatAB, template <typename FloatAB,
typename ZDataType,
typename FloatGemm, typename FloatGemm,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid, FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment