Commit 5eec2aef authored by Anthony Chang

clean up

parent 8079a1b0
@@ -40,7 +40,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
     < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on
-using DeviceGemmInstance = DeviceGemmInstance1;
+using DeviceGemmInstance = DeviceGemmInstance0;
 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
...
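The functional change in this hunk is only which pre-tuned configuration the example builds: the DeviceGemmInstance alias is pointed at DeviceGemmInstance0 instead of DeviceGemmInstance1, while both instance definitions stay in the file. A minimal sketch of that selection pattern, using a hypothetical TunedGemm template rather than CK's real DeviceGemm_Xdl_CShuffle parameter list:

#include <cstdio>

// Hypothetical stand-in for a heavily parameterized device GEMM template; the real
// CK instance carries dozens of tile/wave/vector-width parameters.
template <int BlockSize, int MPerBlock, int NPerBlock, int KPerBlock>
struct TunedGemm
{
    static void Describe()
    {
        std::printf("block %d, tile %dx%dx%d\n", BlockSize, MPerBlock, NPerBlock, KPerBlock);
    }
};

// Two candidate tunings, mirroring DeviceGemmInstance0/1 in the example source.
using DeviceGemmInstance0 = TunedGemm<256, 256, 128, 32>;
using DeviceGemmInstance1 = TunedGemm<256, 128, 256, 32>;

// The example selects one tuning here; switching configurations is a one-line edit.
using DeviceGemmInstance = DeviceGemmInstance0;

int main() { DeviceGemmInstance::Describe(); }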
@@ -234,11 +234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
     static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
-    // TODO ANT: implement bias combination
+    // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
 #if 0
-    // TODO ANT: use alias
+    // TODO: use alias
     static constexpr index_t NumDimGemm0M = NumDimM;
     static constexpr index_t NumDimGemm0N = NumDimN;
     static constexpr index_t NumDimGemm0K = NumDimK;
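NumAcc0Bias and NumAcc1Bias are compile-time counts taken from the bias type containers, and the static_assert turns any attempt to attach biases into a build error while the feature is unimplemented. A rough sketch of the same guard, with std::tuple standing in for CK's type list and a made-up struct name (the sketch checks both counts):

#include <tuple>

template <typename Acc0BiasTuple, typename Acc1BiasTuple>
struct BatchedGemmSoftmaxGemmSketch
{
    // Compile-time bias counts derived from the tuples of bias element types.
    static constexpr int NumAcc0Bias = std::tuple_size_v<Acc0BiasTuple>;
    static constexpr int NumAcc1Bias = std::tuple_size_v<Acc1BiasTuple>;

    // Reject any instantiation that requests biases until bias addition exists.
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
};

// Compiles: no biases requested.
template struct BatchedGemmSoftmaxGemmSketch<std::tuple<>, std::tuple<>>;

// Would fail to compile: one acc0 bias requested.
// template struct BatchedGemmSoftmaxGemmSketch<std::tuple<float>, std::tuple<>>;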
@@ -329,7 +329,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
     // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
     // transformation overhead
-    // TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+    // TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
     // extract subsequence and shuffle them.
     const index_t num_dims = NumDimG + NumDimN + NumDimO;
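The comment block in this hunk describes how the vgrad descriptor is built: rather than constructing a descriptor for v_gs_os_ns and then permuting it, the lengths and strides are reordered up front so the O dimensions land last, matching the row-major output. A simplified illustration of that reordering on plain std::vector lengths or strides, assuming the input ordering is G dims, then O dims, then N dims (the real code works on CK descriptor arguments; names here are illustrative):

#include <cassert>
#include <vector>

// Reorder v_gs_os_ns lengths/strides into gs_ns_os order before any descriptor is
// built, so no extra permutation transform is needed afterwards.
// Assumed input ordering: [G0..G(g-1), O0..O(o-1), N0..N(n-1)].
std::vector<int> rearrange_gs_os_ns_to_gs_ns_os(const std::vector<int>& in,
                                                int num_dim_g,
                                                int num_dim_o,
                                                int num_dim_n)
{
    assert(static_cast<int>(in.size()) == num_dim_g + num_dim_o + num_dim_n);

    std::vector<int> out;
    out.reserve(in.size());

    // G dims stay in front.
    out.insert(out.end(), in.begin(), in.begin() + num_dim_g);
    // N dims move ahead of the O dims ...
    out.insert(out.end(), in.begin() + num_dim_g + num_dim_o, in.end());
    // ... and O dims go last, matching the row-major vgrad output.
    out.insert(out.end(), in.begin() + num_dim_g, in.begin() + num_dim_g + num_dim_o);

    return out;
}

// Example: lengths {G=2, O=64, N=128} -> {2, 128, 64}; apply the same to the strides.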
@@ -410,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
     // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
     // transformation overhead
-    // TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+    // TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
     // extract subsequence and shuffle them.
     const index_t num_dims = NumDimG + NumDimN + NumDimO;
@@ -738,7 +738,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
           c_grid_desc_g_m_n_,
           type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
     {
-        // TODO ANT: implement bias addition
+        // TODO: implement bias addition
         ignore = p_acc0_biases;
         ignore = p_acc1_biases;
         ignore = acc0_biases_gs_ms_ns_lengths;
@@ -950,7 +950,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return false;
         }
-        // TODO ANT: Check if tensor specialization & strides mismatch
+        // TODO: Check if tensor specialization & strides mismatch
         // Check if C permute dimension matches GEMM + GEMM shape
         const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
...
@@ -209,9 +209,6 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
         const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
         const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
-        // TODO ANT: is this necessary?
-        // block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
-
         const index_t idx_ksplit = block_1d_id / (M0 * N0);
         block_1d_id = block_1d_id % (M0 * N0);
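The k-split mapping in this hunk decomposes the flat block id with one divide and one modulo: the quotient over M0 * N0 picks the k-split slice, and the remainder is the tile index inside the M0 x N0 grid of output tiles. A small standalone illustration with plain integers (the real BlockToCTileMap_KSplit_M00_N0_M01Adapt additionally swizzles the tile index via M01, which is omitted here):

#include <cstdio>

int main()
{
    // M0 x N0 output tiles per k-split slice, KSplit slices in total.
    const int M0 = 4, N0 = 8, KSplit = 2;

    for(int block_1d_id = 0; block_1d_id < M0 * N0 * KSplit; ++block_1d_id)
    {
        const int idx_ksplit = block_1d_id / (M0 * N0); // which k-split slice
        const int tile_id    = block_1d_id % (M0 * N0); // tile within that slice
        const int m0         = tile_id / N0;            // tile row (plain row-major, no M01 swizzle)
        const int n0         = tile_id % N0;            // tile column

        if(block_1d_id < 2 || block_1d_id == M0 * N0) // print a few representative ids
            std::printf("block %2d -> ksplit %d, tile (%d, %d)\n", block_1d_id, idx_ksplit, m0, n0);
    }
    return 0;
}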
...
@@ -54,8 +54,7 @@ template <typename SrcData,
           typename SrcDesc,
           typename DstDesc,
           typename ElementwiseOperation,
-          typename SliceLengths, // TODO ANT: can we generalize to allow sub-wg slice transfer? need
-          // to distinguish what dimensions are spread across waves
+          typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
@@ -137,7 +136,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
         static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t src_offset = src_desc.CalculateOffset(
                 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-            // Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
             SrcData v;
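The removed line was leftover debug output; the surrounding loop is the interesting part: static_for unrolls the per-vector copy at compile time, so each iteration index i, and therefore each src_offset, is a compile-time constant. A rough stand-in for that idiom built on std::integer_sequence rather than CK's own static_for (names and constants below are illustrative):

#include <array>
#include <cstdio>
#include <utility>

// Minimal compile-time loop: invokes f(std::integral_constant<int, I>{}) for I = 0..N-1.
template <typename F, int... Is>
void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

int main()
{
    constexpr int DstScalarPerVector = 4;  // scalars copied per vector write
    constexpr int src_slice_origin   = 16; // assumed flat offset of the slice
    constexpr int dst_scalar_step    = 1;  // assumed stride between scalars

    std::array<float, 32> src{};
    std::array<float, DstScalarPerVector> v{};

    // Each i is an integral_constant, so src_offset is known at compile time,
    // mirroring the constexpr CalculateOffset call in ThreadwiseTensorSliceTransfer_v1r3.
    static_for<DstScalarPerVector>([&](auto i) {
        constexpr int idx        = decltype(i)::value;
        constexpr int src_offset = src_slice_origin + idx * dst_scalar_step;
        v[idx]                   = src[src_offset];
    });

    std::printf("copied %d scalars starting at offset %d\n", DstScalarPerVector, src_slice_origin);
    return 0;
}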
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 32)
+            if(err_count < 5)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
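The only change to check_err is the cap on printed mismatches (first 5 instead of first 32); every failing element still updates err_count and max_err. A self-contained sketch of that reporting pattern for float buffers (tolerance, function name, and messages are illustrative, not the library's actual signature):

#include <cmath>
#include <cstdio>
#include <vector>

// Compare out against ref, track the worst error and total count, but only print
// the first few offenders so a badly broken run does not flood the console.
bool check_err_sketch(const std::vector<float>& out, const std::vector<float>& ref, float atol = 1e-3f)
{
    int   err_count = 0;
    float max_err   = 0.f;

    for(std::size_t i = 0; i < out.size() && i < ref.size(); ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol)
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5) // cap the printed mismatches, keep counting the rest
                std::fprintf(stderr, "out[%zu] != ref[%zu]: %f != %f\n", i, i, out[i], ref[i]);
        }
    }

    if(err_count > 0)
        std::fprintf(stderr, "%d mismatches, max abs err %f\n", err_count, max_err);

    return err_count == 0;
}

int main()
{
    std::vector<float> ref{1.f, 2.f, 3.f};
    std::vector<float> out{1.f, 2.5f, 3.f};
    return check_err_sketch(out, ref) ? 0 : 1;
}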