Commit 5eec2aef authored by Anthony Chang

clean up

parent 8079a1b0
@@ -40,7 +40,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
-using DeviceGemmInstance = DeviceGemmInstance1;
+using DeviceGemmInstance = DeviceGemmInstance0;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
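For context, an instance like the one above is driven through CK's usual argument/invoker flow, and ReferenceGemm plays the same role on the host side for validation. A minimal host-side sketch, assuming device buffers and problem sizes are already set up (pointer and stride names here are illustrative, not from this diff):

// Hypothetical driver for the DeviceGemmInstance selected above.
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(p_a, p_b, p_c, // device pointers (assumed allocated)
                                  M, N, K,
                                  StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{});
if(!gemm.IsSupportedArgument(argument))
{
    std::cerr << "this device instance does not support the given problem" << std::endl;
    return 1;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); // true: time the kernel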
@@ -234,11 +234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
-// TODO ANT: implement bias combination
+// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
#if 0
-// TODO ANT: use alias
+// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
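The bias counts above are just the arity of a type tuple, so an empty tuple encodes "no bias tensors" and satisfies the static_assert. A minimal sketch of that convention, assuming ck::Tuple is the carrier type:

// No bias tensors for either GEMM: both tuples are empty.
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static_assert(Acc0BiasDataType::Size() == 0 && Acc1BiasDataType::Size() == 0,
              "Bias addition is unimplemented");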
@@ -329,7 +329,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
-// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
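The shuffle the comment describes can be pictured as splicing the flat length/stride arrays: keep the G block, move the N block forward, and append the O block last. A minimal standalone sketch, assuming std::vector inputs laid out [G..., O..., N...] (function and parameter names hypothetical):

#include <vector>

// v_gs_os_ns lengths -> vgrad_gs_ns_os lengths: O dims go last (output is row-major).
std::vector<int> shuffle_lengths(const std::vector<int>& in, int g, int o, int n)
{
    std::vector<int> out;
    out.reserve(g + n + o);
    out.insert(out.end(), in.begin(), in.begin() + g);         // G dims
    out.insert(out.end(), in.begin() + g + o, in.end());       // N dims
    out.insert(out.end(), in.begin() + g, in.begin() + g + o); // O dims
    return out;
}

The same permutation would be applied to the stride array so lengths and strides stay paired.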
@@ -410,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
-// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
@@ -738,7 +738,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
-// TODO ANT: implement bias addition
+// TODO: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
@@ -950,7 +950,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false;
}
-// TODO ANT: Check if tensor specialization & strides mismatch
+// TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
......
@@ -209,9 +209,6 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
-// TODO ANT: is this necessary?
-// block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
......
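The mapping above carves a flat workgroup index into a k-split slice index and a 2D tile coordinate. A minimal sketch of the arithmetic with plain integers, ignoring the M01 column swizzle the real adapter also applies:

// Grid has KSplit * M0 * N0 blocks; peel off the K slice first.
const int idx_ksplit = block_1d_id / (M0 * N0); // which K partial this block computes
const int tile_id    = block_1d_id % (M0 * N0); // flat tile index within that slice
const int idx_m0     = tile_id / N0;            // macro-tile row in C
const int idx_n0     = tile_id % N0;            // macro-tile column in C

Each k-split slice produces a partial C tile; the partials are reduced across slices afterwards.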
@@ -54,8 +54,7 @@ template <typename SrcData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
-typename SliceLengths, // TODO ANT: can we generalize to allow sub-wg slice transfer? need
-                       // to distinguish what dimensions are spread across waves
+typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
@@ -137,7 +136,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-// Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
SrcData v;
......
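static_for in the hunk above is CK's compile-time counted loop: the body is invoked with the index wrapped in an integral-constant-like type, so i.value can feed constexpr offset math such as CalculateOffset. A minimal sketch of the idiom (surrounding buffers assumed):

// Unrolls at compile time; i.value is a constexpr index on each iteration.
ck::static_for<0, 4, 1>{}([&](auto i) {
    constexpr ck::index_t offset = i.value * 2; // usable in constant expressions
    dst[i.value] = src[offset];
});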
@@ -148,7 +148,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
-if(err_count < 32)
+if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
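Lowering the print threshold keeps validation logs readable: every mismatch still updates max_err and err_count, but only the first few are echoed. A simplified sketch of the same pattern in isolation (tolerance handling and formatting reduced):

// Track all errors in aggregate, print only the first few.
double max_err = 0;
int err_count  = 0;
for(std::size_t i = 0; i < out.size(); ++i)
{
    const double err = std::abs(out[i] - ref[i]);
    if(err > tol)
    {
        max_err = err > max_err ? err : max_err;
        err_count++;
        if(err_count < 5)
            std::cerr << "out[" << i << "] != ref[" << i << "]" << std::endl;
    }
}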