Commit daebef99 authored by fsx950223's avatar fsx950223
Browse files

use int32 as z output

parent cf2490e0
......@@ -95,7 +95,7 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
unsigned short* z_matrix_ptr =
auto z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
......@@ -535,6 +535,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......
......@@ -95,7 +95,7 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
unsigned short* z_matrix_ptr =
auto z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
......@@ -528,6 +528,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
......
......@@ -21,6 +21,7 @@
namespace ck {
template <typename DataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
......@@ -1236,7 +1237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
......@@ -1552,7 +1553,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
......
......@@ -21,6 +21,7 @@
namespace ck {
template <typename DataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
......@@ -1146,7 +1147,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
......@@ -1484,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment