Commit daebef99 authored by fsx950223's avatar fsx950223
Browse files

use int32 as z output

parent cf2490e0
...@@ -95,7 +95,7 @@ __global__ void ...@@ -95,7 +95,7 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
unsigned short* z_matrix_ptr = auto z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
...@@ -535,6 +535,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -535,6 +535,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -95,7 +95,7 @@ __global__ void ...@@ -95,7 +95,7 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id(); const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset); ck::philox ph(seed, global_thread_id, offset);
unsigned short* z_matrix_ptr = auto z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
...@@ -528,6 +528,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -528,6 +528,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename DataType, template <typename DataType,
typename ZDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -1236,7 +1237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1236,7 +1237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
...@@ -1552,7 +1553,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1552,7 +1553,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename DataType, template <typename DataType,
typename ZDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -1146,7 +1147,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1146,7 +1147,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
...@@ -1484,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1484,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment