Commit bb10822b authored by mtgu0705

Updated the int4 per-group dequant; hit a functional bug.

parent 624c6d3e
 add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
 add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
 add_example_executable(example_gemm_fp16int8_b_scale gemm_fp16int8_b_scale.cpp)
+add_example_executable(example_gemm_fp16int4_b_scale gemm_fp16int4_b_scale.cpp)
 add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
@@ -61,7 +61,7 @@ using CDEElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 // static constexpr ck::index_t Scale_Block_M = 128;
-static constexpr ck::index_t Scale_Block_N = 128;
+static constexpr ck::index_t Scale_Block_N = 1;
 static constexpr ck::index_t Scale_Block_K = 128;
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
@@ -217,7 +217,8 @@ int main(int argc, char* argv[])
         a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         quant_b0_k_n.GenerateTensorValue(GeneratorTensor_1<QuantDataType>{});
         // a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        // b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
         break;
     case 3:
         a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
...
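The two hunks above shrink the B-scale group along N to a single column (`Scale_Block_N = 1`) and switch the scale tensor `b1_k_n` from all-ones to random values, so the per-group dequant path is actually exercised instead of multiplying by 1. A minimal sketch of what the generator swap means, assuming CK's host generators fill with a constant (`GeneratorTensor_1`) versus uniform floats in `[min, max)` (`GeneratorTensor_3`); these are simplified stand-ins, not CK's real definitions:

```cpp
// Sketch (assumption): simplified stand-ins for CK's host tensor
// generators from its host_tensor utilities.
#include <random>

template <typename T>
struct GeneratorTensor_1 // fills every element with the same value
{
    T value{1};
    template <typename... Is>
    T operator()(Is...) const { return value; }
};

template <typename T>
struct GeneratorTensor_3 // uniform floats drawn from [min_, max_)
{
    float min_;
    float max_;

    template <typename... Is>
    T operator()(Is...) const
    {
        static std::mt19937 rng{0}; // fixed seed for reproducible tests
        std::uniform_real_distribution<float> dist(min_, max_);
        return static_cast<T>(dist(rng));
    }
};
```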
@@ -39,7 +39,6 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
           index_t KPack>
 constexpr auto BlockGemmBScalePipeline_Selector()
 {
-    printf("I'm Here\n");
     return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
                                                    BlockSize,
                                                    ADataType,
...
@@ -440,7 +440,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                     c_thread_buf(Number<c_offset>{}) +=
                         c_thread_buf_per_scale[Number<t>{}] *
                         // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-                        type_convert<AccDataType>(b_scale_thread_buf[I0]);
+                        type_convert<AccDataType>(b_scale_thread_buf[n0]);
                 });
             });
         });
...
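The one-token change above is the heart of the fix: the scale buffer is now indexed by the N-wave loop variable `n0`, so each N slice is dequantized by its own group scale instead of every slice reusing the first scale (`I0`). A self-contained sketch of the intended accumulation, with plain arrays and hypothetical sizes standing in for CK's static buffers and `static_for` loops:

```cpp
#include <cstddef>

// Sketch (assumption): hypothetical flat layout; `NXdlPerWave` and
// `AccPerThread` stand in for the pipeline's compile-time loop bounds.
// The bug: indexing the scale with a fixed 0 applied wave 0's scale to
// every N slice; indexing with n0 applies each slice's own group scale.
void accumulate_with_b_scale(float*       c_thread_buf,           // [NXdlPerWave][AccPerThread]
                             const float* c_thread_buf_per_scale, // [AccPerThread]
                             const float* b_scale_thread_buf,     // [NXdlPerWave]
                             std::size_t  NXdlPerWave,
                             std::size_t  AccPerThread)
{
    for(std::size_t n0 = 0; n0 < NXdlPerWave; ++n0)
        for(std::size_t t = 0; t < AccPerThread; ++t)
            c_thread_buf[n0 * AccPerThread + t] +=
                c_thread_buf_per_scale[t] * b_scale_thread_buf[n0]; // was [0]
}
```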
@@ -360,7 +360,8 @@ struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
         return false;
     }
-    if(ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+    // if(ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+    if(ScaleBlockK != KPerBlock)
     {
         printf("Return 1\n");
         return false;
...
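This hunk relaxes the argument check: with `Scale_Block_N = 1`, the old `ScaleBlockN % NPerBlock != 0` test would reject every fine-grained N group, so only the K-granularity constraint is kept. Lifted out as plain predicates (hypothetical free-function signatures, for illustration only):

```cpp
// Sketch (assumption): the support predicate before and after the change.
bool is_supported_before(int ScaleBlockN, int ScaleBlockK, int NPerBlock, int KPerBlock)
{
    // rejects ScaleBlockN == 1 whenever NPerBlock > 1
    return (ScaleBlockN % NPerBlock == 0) && (ScaleBlockK == KPerBlock);
}

bool is_supported_after(int ScaleBlockK, int KPerBlock)
{
    // N granularity is no longer constrained at this level
    return ScaleBlockK == KPerBlock;
}
```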
@@ -1359,15 +1359,15 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);
-        const index_t ScaleSliceSizeM = 1;
-        const index_t ScaleSliceSizeN = 1;
+        //const index_t ScaleSliceSizeM = 1;
+        const index_t ScaleSliceSizeN = NXdlPerWave;
         const index_t ScaleSliceSizeK = 1;
         // constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
         //     make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
         // auto a_scale_thread_copy =
         //     ThreadwiseTensorSliceTransfer_v2<AScaleType,
...
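Matching the pipeline change, each thread's B-scale slice grows from a single scalar to one scalar per N-wave slice, which is what lets `b_scale_thread_buf[n0]` address distinct values. Shapes only, with plain constants in place of CK's tensor descriptors (the `NXdlPerWave` value here is hypothetical):

```cpp
// Sketch (assumption): per-thread B-scale slice shape, before vs after.
constexpr int NXdlPerWave     = 4;           // hypothetical tile parameter
constexpr int ScaleSliceSizeK = 1;
// before: constexpr int ScaleSliceSizeN = 1;    // one scale per thread
constexpr int ScaleSliceSizeN = NXdlPerWave;     // one scale per N-wave slice

// b_scale_thread_desc describes an [N, K] slice; shown here as a flat buffer.
float b_scale_thread_buf[ScaleSliceSizeN * ScaleSliceSizeK]; // indexed by n0
```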
@@ -12,6 +12,7 @@ using half_t = _Float16;
 using int4_t = _BitInt(4);
 using f8_t = _BitInt(8);
 using bf8_t = unsigned _BitInt(8);
+using pk_i4_t = unsigned char;
 inline constexpr auto next_pow2(uint32_t x)
 {
...
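The new `pk_i4_t` alias is a byte-sized carrier for two packed 4-bit integers. A hedged sketch of pack/unpack with sign extension; the nibble order (low nibble = lane 0) is a guess, and CK may use the opposite layout:

```cpp
// Sketch (assumption): two signed int4 lanes packed into one byte,
// which is what an "unsigned char" carrier suggests.
using pk_i4_t = unsigned char;

inline pk_i4_t pack_int4x2(int lo, int hi)
{
    return static_cast<pk_i4_t>((lo & 0xF) | ((hi & 0xF) << 4));
}

inline int unpack_int4(pk_i4_t p, int lane)
{
    int v = (lane == 0) ? (p & 0xF) : ((p >> 4) & 0xF);
    return (v ^ 0x8) - 0x8; // sign-extend 4-bit two's complement
}
```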