Commit bb10822b authored by mtgu0705

Updated the int4 per-group dequant; ran into a function bug.

parent 624c6d3e
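For context: per-group B-scale dequant multiplies each int4 weight by a scale shared across a Scale_Block_K x Scale_Block_N group. A minimal host-side sketch of that mapping (hypothetical names; the kernel changes below fuse this into the GEMM pipeline rather than materializing B):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference of per-group B-scale dequantization.
// scale has extent (K / ScaleBlockK) x (N / ScaleBlockN); with the values set
// in this commit (ScaleBlockK = 128, ScaleBlockN = 1) every column carries its
// own scale per 128-element K group.
std::vector<float> dequant_b(const std::vector<int8_t>& b_int4, // one int4 value per entry
                             const std::vector<float>& scale,
                             int K, int N, int ScaleBlockK, int ScaleBlockN)
{
    std::vector<float> b(static_cast<std::size_t>(K) * N);
    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            b[k * N + n] = static_cast<float>(b_int4[k * N + n]) *
                           scale[(k / ScaleBlockK) * (N / ScaleBlockN) + (n / ScaleBlockN)];
    return b;
}
```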
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_fp16int8_b_scale gemm_fp16int8_b_scale.cpp)
add_example_executable(example_gemm_fp16int4_b_scale gemm_fp16int4_b_scale.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
@@ -61,7 +61,7 @@ using CDEElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 // static constexpr ck::index_t Scale_Block_M = 128;
-static constexpr ck::index_t Scale_Block_N = 128;
+static constexpr ck::index_t Scale_Block_N = 1;
 static constexpr ck::index_t Scale_Block_K = 128;
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
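Scale_Block_N drops from 128 to 1, i.e. the scale granularity along N goes from one scale per 128 columns to one scale per column. An illustrative shape check (example GEMM extents assumed, not taken from this diff):

```cpp
#include <cstdio>

int main()
{
    const int K = 4096, N = 4096;  // assumed example GEMM extents
    const int Scale_Block_N = 1;   // new value: per-column scales
    const int Scale_Block_K = 128; // unchanged: per-128-element K groups
    // B-scale tensor extent is (K / Scale_Block_K) x (N / Scale_Block_N).
    std::printf("b1_k_n scales: %d x %d\n", K / Scale_Block_K, N / Scale_Block_N); // 32 x 4096
    return 0;
}
```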
@@ -217,7 +217,8 @@ int main(int argc, char* argv[])
 a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
 quant_b0_k_n.GenerateTensorValue(GeneratorTensor_1<QuantDataType>{});
 // a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
-b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+// b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
 break;
 case 3:
 a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
......
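With all scale values fixed at 1 (GeneratorTensor_1), a scale-indexing bug like the one fixed below is invisible; GeneratorTensor_3<B1DataType>{0, 1.0} fills the scale tensor with random values instead. A rough stand-in for what such a generator does (simplified sketch, not CK's actual implementation):

```cpp
#include <cstdlib>

// Simplified sketch of a GeneratorTensor_3-style functor: returns uniform
// values in [min_, max_) regardless of the element indices it is called with.
template <typename T>
struct UniformGen
{
    float min_;
    float max_;

    template <typename... Indices>
    T operator()(Indices...) const
    {
        const float u = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
        return static_cast<T>(min_ + u * (max_ - min_));
    }
};
```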
@@ -39,7 +39,6 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
 index_t KPack>
 constexpr auto BlockGemmBScalePipeline_Selector()
 {
-    printf("I'm Here\n");
 return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
 BlockSize,
 ADataType,
......
@@ -440,7 +440,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
 c_thread_buf(Number<c_offset>{}) +=
 c_thread_buf_per_scale[Number<t>{}] *
 // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-type_convert<AccDataType>(b_scale_thread_buf[I0]);
+type_convert<AccDataType>(b_scale_thread_buf[n0]);
 });
 });
 });
......
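This one-token change is the core fix: now that each thread holds NXdlPerWave scales (see the gridwise change further down), the accumulation must pick the scale for its own n0 slot instead of always reading element 0. The loop shape, as a hedged sketch (illustrative names and extents; the real code uses static_for over thread-buffer descriptors):

```cpp
#include <cstddef>

// Hedged sketch of the fixed accumulation: each n0 slot multiplies by its own
// scale, b_scale[n0], where the old code always used b_scale[0].
void accumulate(float* c_thread, const float* c_per_scale, const float* b_scale,
                std::size_t n_xdl_per_wave, std::size_t acc_per_thread)
{
    for(std::size_t n0 = 0; n0 < n_xdl_per_wave; ++n0)
        for(std::size_t t = 0; t < acc_per_thread; ++t)
            c_thread[n0 * acc_per_thread + t] += c_per_scale[t] * b_scale[n0];
}
```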
@@ -360,7 +360,8 @@ struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
 return false;
 }
-if(ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+// if(ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+if(ScaleBlockK != KPerBlock)
 {
 printf("Return 1\n");
 return false;
......
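The device-level argument check is relaxed to match: with ScaleBlockN = 1, ScaleBlockN can no longer be required to be a multiple of NPerBlock, so only the K-granularity stays pinned to the tile. A standalone sketch of the surviving predicate (assumed signature, illustrative only):

```cpp
// Assumed, illustrative form of the relaxed check: N-granularity is now
// handled per-thread in the pipeline, so only ScaleBlockK must equal the
// K tile extent.
bool is_supported_scale_blocks(int /*ScaleBlockN*/, int ScaleBlockK, int KPerBlock)
{
    return ScaleBlockK == KPerBlock;
}
```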
@@ -1359,15 +1359,15 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
 KPerBlock);
-const index_t ScaleSliceSizeM = 1;
-const index_t ScaleSliceSizeN = 1;
+//const index_t ScaleSliceSizeM = 1;
+const index_t ScaleSliceSizeN = NXdlPerWave;
 const index_t ScaleSliceSizeK = 1;
 // constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
 // make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
 // auto a_scale_thread_copy =
 // ThreadwiseTensorSliceTransfer_v2<AScaleType,
......
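The per-thread scale slice along N grows from 1 to NXdlPerWave, so b_scale_thread_desc now describes an NXdlPerWave x ScaleSliceSizeK buffer; this is what makes the b_scale_thread_buf[n0] indexing above well-defined. A plain-array picture of the new layout (hedged sketch; the real code builds a CK naive tensor descriptor):

```cpp
// Hedged sketch of the enlarged per-thread B-scale buffer: one scale per
// NXdl slot along N, times ScaleSliceSizeK (= 1 here) along K.
template <int NXdlPerWave, int ScaleSliceSizeK = 1>
struct BScaleThreadBuf
{
    float data[NXdlPerWave * ScaleSliceSizeK]; // indexed as data[n0 * ScaleSliceSizeK + k]
};
```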
@@ -12,6 +12,7 @@ using half_t = _Float16;
 using int4_t = _BitInt(4);
 using f8_t = _BitInt(8);
 using bf8_t = unsigned _BitInt(8);
+using pk_i4_t = unsigned char;
 inline constexpr auto next_pow2(uint32_t x)
 {
......
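pk_i4_t aliases unsigned char, i.e. one byte carrying two packed 4-bit values. A hedged sketch of the pack/unpack arithmetic such a type implies (the nibble order is an assumption; CK may place the pair differently):

```cpp
#include <cstdint>

// Pack two signed int4 values into one byte: low nibble first (assumed order).
inline uint8_t pack_i4x2(int8_t lo, int8_t hi)
{
    return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
}

// Unpack one of the two int4 values; which: 0 = low nibble, 1 = high nibble.
inline int8_t unpack_i4(uint8_t packed, int which)
{
    int8_t v = (which == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF);
    return (v & 0x8) ? static_cast<int8_t>(v - 16) : v; // sign-extend 4-bit value
}
```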