Commit 27dc055b authored by aska-0096
Browse files

fix a host tensor bug and clean up flash-attn code

parent 4ddda63b
......@@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
......@@ -64,18 +65,18 @@ using DeviceOpInstanceKKNN =
BElementOp,
CDEElementOp,
GemmSpec,
ABSpec,
ABSpec,
ASpec,
BSpec,
DESpec,
256,
128,
256,
8,
128,
4,
8,
16,
16,
4,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -252,21 +253,6 @@ int main(int argc, char* argv[])
ck::index_t K0 = 2048;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
......@@ -277,13 +263,43 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
G0 = std::stoi(argv[4]);
G1 = std::stoi(argv[5]);
M0 = std::stoi(argv[6]);
M1 = std::stoi(argv[7]);
N0 = std::stoi(argv[8]);
N1 = std::stoi(argv[9]);
K0 = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
......
......@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
const int nrepeat = 1;
const int nrepeat = 100;
// printf("Warm up 1 time\n");
printf("Warm up 1 time\n");
// warm up
// kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
printf("Start running %d times...\n", nrepeat);
......
......@@ -771,6 +771,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp: Arch check failure\n");
return false;
}
}
......@@ -785,6 +786,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
printf("GridwiseOp: Validity check failure\n");
return false;
}
......@@ -799,6 +801,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-m check failure\n");
return false;
}
}
......@@ -807,6 +810,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
}
}
......@@ -817,6 +821,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-n check failure\n");
return false;
}
}
......@@ -825,6 +830,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-k check failure\n");
return false;
}
}
......@@ -838,6 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
0))
{
printf("DeviceOp: Vector Access D-n check failure\n");
valid_d_access = false;
}
});
......@@ -854,6 +861,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
0) ||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
{
printf("DeviceOp: Vector Access E-n check failure\n");
return false;
}
......
......@@ -352,6 +352,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
printf("M = %d, L = %d, K = %d, N = %d\n", M, L, K, N);
const auto KPerBlock = K0PerBlock * K1Value;
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
{
......@@ -730,7 +732,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// dst Rowlane
// 0x76543210 0xfedcba98
// src Rowlane
0x76543210, 0xfedcba98>{tensor_operation::element_wise::PassThrough{}};
0x76543210, 0xfedcba98,
false>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
......
......@@ -148,14 +148,12 @@ __global__ void
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
//printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
//printf("before compute_ptr_offset call");
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -170,13 +168,9 @@ __global__ void
DsPointer p_ds_grid_grp;
//printf("before allocate pointer d");
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
//printf("before entry");
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
......@@ -469,16 +463,23 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
if(!valid)
{
printf("GridwiseOp: D descriptor dimension check failure\n");
return false;
}
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
{
printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
{
printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
......@@ -570,7 +571,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const CDEElementwiseOperation& cde_element_op,
const Block2CTileMap& block_2_ctile_map)
{
//printf("safe entry");
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
......@@ -716,7 +716,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
c_thread_buf,
K0BlockMainLoop);
/*******************************************************************************/
//printf("safe 1");
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
......
......@@ -1311,10 +1311,11 @@ template <typename SrcData,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector,
uint32_t LowEightRowlaneIdx,
uint32_t HighEightRowLaneIdx,
bool IntraRowSwizzlePerm,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
......@@ -1389,29 +1390,33 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
SrcData v_this_row, v_theother_row;
// int type temp value due to intrinsic requirement
int temp = 0;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row swizzle permute
if constexpr(IntraRowSwizzlePerm){
// origin: 0xfedcba98, 0x76543210
temp = __builtin_amdgcn_permlane16(temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
v_this_row = type_convert<float>(temp);
}
// apply inter-row permute.
temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
v_theother_row = type_convert<float>(temp);
if(get_thread_local_1d_id() % 32 < 16){
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_theother_row);
}
else{
// apply type convert
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
}
SrcData d = 0;
int temp = 0;
temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
d = type_convert<float>(temp);
if(get_thread_local_1d_id() % 32 < 16){
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(d);
}
else{
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(d);
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_this_row);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
}
});
});
......
......@@ -972,7 +972,6 @@ inline __host__ __device__ constexpr int type_convert<int, float>(float x)
float fp32;
int int32;
} u = {x};
// u.fp32 = x;
return u.int32;
}
......@@ -985,7 +984,6 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
int int32;
float fp32;
} u = {x};
// u.fp32 = x;
return u.fp32;
}
......
......@@ -396,7 +396,7 @@ struct Tensor
}
case 6: {
auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
(*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
(*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment