gaoqiong / composable_kernel_ROCM

Commit 03c2ba3a, authored Dec 04, 2024 by aska-0096

    bug fix + performance opt + clangformat

parent 1a324dfb
Showing 14 changed files with 372 additions and 315 deletions (+372, -315)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp  +25 -25
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp  +21 -21
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp  +48 -50
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp  +57 -36
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp  +21 -24
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp  +24 -6
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +25 -17
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp  +63 -68
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +1 -1
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp  +15 -15
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp  +13 -4
profiler/include/profiler/profile_gemm_b_scale_impl.hpp  +40 -40
profiler/src/profile_gemm_b_scale.cpp  +18 -7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
@@ -170,7 +170,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
               typename BBlockBuffer,
               typename BBlockTransferStep,
               typename CThreadBuffer,
-              //BScale Thread Copy
+              // BScale Thread Copy
               typename BScaleGridBuffer,
               typename BScaleGridDesc,
               typename BScaleThreadDesc,
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
@@ -719,12 +719,12 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    const BScaleGridDesc& b_scale_grid_desc,
-                   //BScaleThreadCopy
+                   // BScaleThreadCopy
                    const BScaleThreadDesc& b_scale_thread_desc,
                    BScaleThreadTransfer& b_scale_thread_copy,
                    const BScaleGridBuffer& b_scale_grid_buf,
                    const BScaleThreadTransferStep& b_scale_thread_copy_step,
-                   //num loop
+                   // num loop
                    index_t num_loop,
                    index_t num_loop_per_scale) const
    {

@@ -864,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                        }
                    });
-                   // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
+                   // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
+                   // {
                    //     constexpr index_t c_offset =
                    //         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                    //     c_thread_buf(Number<c_offset>{}) +=
                    //         c_thread_buf_per_scale[Number<t>{}] *
                    //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                    // });
                });
            });
            __builtin_amdgcn_sched_barrier(0);

@@ -986,7 +986,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                    //         c_thread_buf_per_scale[Number<t>{}] *
                    //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                    // });
                });
            });
            __builtin_amdgcn_sched_barrier(0);

@@ -1084,7 +1083,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                    //         c_thread_buf_per_scale[Number<t>{}] *
                    //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                    // });
                });
            });
            __builtin_amdgcn_sched_barrier(0);
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
@@ -295,25 +295,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
-                   //BScaleThreadCopy
+                   // BScaleThreadCopy
                    const BScaleGridDesc& b_scale_grid_desc,
                    const BScaleThreadDesc& b_scale_thread_desc,
                    BScaleThreadTransfer& b_scale_thread_copy,
                    const BScaleGridBuffer& b_scale_grid_buf,
                    const BScaleThreadTransferStep& b_scale_thread_copy_step,
-                   //num loop
+                   // num loop
                    index_t num_loop,
                    index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);
-       ignore = num_loop_per_scale;

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

-       //B scale buffer
+       // B scale buffer
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

@@ -328,14 +327,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
-                                   make_tuple(n0, I0, I0),
+                                   make_tuple(n0, I0),
                                    b_scale_thread_buf);

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });

+       if(num_loop_per_scale == 1)
+       {
+           b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                  b_scale_thread_copy_step.At(Number<2>{}));
+       }
+       else
+       {
+           b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                  b_scale_thread_copy_step.At(Number<1>{}));
+       }

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

@@ -350,7 +358,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // Initialize C
        c_thread_buf.Clear();

-       auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
+       // need actually?
+       StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                                 AccDataType,
+                                 1,
+                                 xdlops_gemm.GetRegSizePerXdlops(),
+                                 true>
+           c_thread_buf_per_scale;

        // Local prefetch 1
        block_sync_lds();

@@ -415,8 +429,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                        // constexpr index_t c_offset =
                        //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                    });

@@ -455,15 +468,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
                                        b_scale_thread_desc,
-                                       make_tuple(n0, I0, I0),
+                                       make_tuple(n0, I0),
                                        b_scale_thread_buf);

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<0>{}));
            });

+           if((i + 2) % num_loop_per_scale == 0)
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<2>{}));
+           }
+           else
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<1>{}));
+           }

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
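Note on the stepping logic above: the v3 pipeline now walks the B-scale window with three distinct slice-window steps instead of ignoring `num_loop_per_scale`. A minimal standalone sketch of the hot-loop pattern in plain C++ (`ScaleWindow` and every other name below are illustrative stand-ins, not the CK thread-copy API):

```cpp
#include <cstdio>

// Hypothetical stand-in for the B-scale thread-copy window position.
struct ScaleWindow
{
    int n = 0; // index along the N-scale dimension
    int k = 0; // index along the K-scale dimension
    void move(int dn, int dk) { n += dn; k += dk; }
};

// Mirrors the branch added in the hot loop: step 0 walks the n0 repeats inside
// a block, step 2 rewinds N and advances one ScaleBlockK group, step 1 rewinds
// N and stays in the current ScaleBlockK group.
void hot_loop_scale_steps(ScaleWindow& w,
                          int num_loop,
                          int num_loop_per_scale,
                          int n_repeat,
                          int n_wave_stride, // NWaves * NPerXdl in the kernel
                          int n_per_block)   // NPerBlock in the kernel
{
    for(int i = 0; i < num_loop; ++i)
    {
        for(int n0 = 0; n0 < n_repeat; ++n0)
            w.move(n_wave_stride, 0); // b_scale_thread_copy_step.At(Number<0>{})

        if((i + 2) % num_loop_per_scale == 0)
            w.move(-n_per_block, 1); // At(Number<2>{}): next K-scale group
        else
            w.move(-n_per_block, 0); // At(Number<1>{}): same K-scale group
    }
}

int main()
{
    ScaleWindow w;
    hot_loop_scale_steps(w, /*num_loop=*/8, /*num_loop_per_scale=*/2,
                         /*n_repeat=*/4, /*n_wave_stride=*/32, /*n_per_block=*/128);
    std::printf("final window: n=%d, k=%d\n", w.n, w.k); // n returns to 0, k advanced 4 groups
    return 0;
}
```

The `(i + 2)` offset presumably accounts for the iterations already consumed by the prologue prefetch, so the K-scale index advances exactly once per ScaleBlockK worth of KPerBlock tiles.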
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
@@ -268,13 +268,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
-                   //BScaleThreadCopy
+                   // BScaleThreadCopy
                    const BScaleGridDesc& b_scale_grid_desc,
                    const BScaleThreadDesc& b_scale_thread_desc,
                    BScaleThreadTransfer& b_scale_thread_copy,
                    const BScaleGridBuffer& b_scale_grid_buf,
                    const BScaleThreadTransferStep& b_scale_thread_copy_step,
-                   //num loop
+                   // num loop
                    index_t num_loop,
                    index_t num_loop_per_scale) const
    {

@@ -284,7 +284,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

-       //B scale buffer
+       // B scale buffer
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

@@ -409,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                                        make_tuple(n0, I0),
                                        b_scale_thread_bufs(lds_read_reg_buf));

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<0>{}));
            });

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<1>{}));

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);

@@ -426,7 +426,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();

@@ -452,8 +451,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                        // constexpr index_t c_offset =
                        //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                    });

@@ -462,7 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                        c_thread_buf(Number<c_offset>{}) +=
                            c_thread_buf_per_scale[Number<t>{}] *
                            type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
                    });
                });
            });
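The accumulation in the hunk above keeps an unscaled per-xdlops partial sum and folds it into the main C accumulator with the B scale of the current n0 block. A scalar restatement of that pattern (plain C++ with illustrative names, not the CK buffer types):

```cpp
#include <cstddef>
#include <vector>

// c_thread_buf           : final accumulators, indexed by the (m0, n0, t) offset
// c_thread_buf_per_scale : unscaled partial sums for the current (m0, n0) micro-tile
// b_scale_n0             : dequantization scale that applies to this n0 block
void fold_per_scale_accum(std::vector<float>& c_thread_buf,
                          const std::vector<float>& c_thread_buf_per_scale,
                          float b_scale_n0,
                          std::size_t c_offset_base)
{
    for(std::size_t t = 0; t < c_thread_buf_per_scale.size(); ++t)
        c_thread_buf[c_offset_base + t] += c_thread_buf_per_scale[t] * b_scale_n0;
}
```

Accumulating unscaled and multiplying once per micro-tile keeps the scale multiply out of the inner MFMA sequence.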
@@ -521,7 +520,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();

@@ -640,7 +638,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
@@ -218,7 +218,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
        };

        constexpr index_t minimum_occupancy =
-           BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+           BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
+               ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
+                  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
+                     ? 1
+                     : 2
+               : 2;

        if(has_main_k_block_loop)
        {
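The occupancy hint above is no longer a pure scheduler switch: intrawave v3 pipelines with a large tile footprint now request a minimum occupancy of 1 instead of 2. A constexpr restatement of the heuristic (the enums and helper below are illustrative, not the CK types):

```cpp
#include <cstddef>

enum class Scheduler { Intrawave, Interwave };
enum class PipelineVersion { v1, v2, v3, v4 };

// Same decision as the new ternary in the diff: the product
// MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) is compared against a
// 128 x 128 x 64 x 2-byte reference tile.
constexpr int minimum_occupancy(Scheduler sched,
                                PipelineVersion ver,
                                std::size_t m_per_block,
                                std::size_t n_per_block,
                                std::size_t k_per_block,
                                std::size_t a_elem_bytes)
{
    const std::size_t footprint = m_per_block * n_per_block * k_per_block * a_elem_bytes;
    const std::size_t threshold = std::size_t{128} * 128 * 64 * 2;

    if(sched == Scheduler::Intrawave)
        return (ver == PipelineVersion::v3 && footprint > threshold) ? 1 : 2;
    return 2;
}

static_assert(minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v3, 256, 128, 64, 2) == 1, "");
static_assert(minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v2, 256, 128, 64, 2) == 2, "");
static_assert(minimum_occupancy(Scheduler::Interwave, PipelineVersion::v3, 256, 128, 64, 2) == 2, "");
```

Requesting a lower minimum occupancy gives each workgroup more register and LDS headroom, which is the usual trade for the larger v3 tiles.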
@@ -664,7 +669,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
-       return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
-                       p_b_scale, KBatch, a_element_op, b_element_op, c_element_op};
+       return Argument{p_a,
+                       p_b,
+                       p_c,
+                       M,
+                       N,
+                       K,
+                       StrideA,
+                       StrideB,
+                       StrideC,
+                       p_b_scale,
+                       KBatch,
+                       a_element_op,
+                       b_element_op,
+                       c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -13,9 +13,9 @@ namespace ck {

__host__ __device__ inline half4_t pki4_to_half4(int q)
{
-   const int LO = 0x000f000f;
-   const int HI = 0x00f000f0;
-   const int EX = 0x64006400;
+   constexpr int LO = 0x000f000f;
+   constexpr int HI = 0x00f000f0;
+   constexpr int EX = 0x64006400;

    // Guarantee that the `(a & b) | c` operations are LOP3s.
    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);

@@ -23,9 +23,9 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
    int hi = amd_assembly_and_or_b32(q, HI, EX);

    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
    // directly into `SUB` and `ADD`.
-   const int SUB = 0xE408E408; //-8
-   const int MUL = 0x2c002c00; // 1/16
-   const int ADD = 0xd480d480; //-79
+   constexpr int SUB = 0xE408E408; //-8
+   constexpr int MUL = 0x2c002c00; // 1/16
+   constexpr int ADD = 0xd480d480; //-79

    vector_type<half_t, 4> res;

@@ -34,7 +34,15 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));

+#if 0
+    asm volatile("v_and_or_b32 %0, %4, %5, %7 \n \
+                  v_and_or_b32 %1, %4, %6, %7 \n \
+                  v_pk_add_f16 %2, %0, %8 \n \
+                  v_pk_fma_f16 %3, %1, %9, %10 \
+                  "
+                 : "=v"(lo), "=v"(hi), "=v"(res.template AsType<half2_t>()(Number<0>{})), "=v"(res.template AsType<half2_t>()(Number<1>{}))
+                 : "v"(q), "v"(LO), "v"(HI), "s"(EX), "s"(SUB), "v"(MUL), "s"(ADD), "0"(lo), "1"(hi));
+#endif

    return res.template AsType<half4_t>()[Number<0>{}];
}
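pki4_to_half4 above relies on the usual exponent-bias trick for weight-only int4 kernels: OR-ing a 4-bit value n into the mantissa of the half bit pattern 0x6400 yields the value 1024 + n, so a packed subtract (or, for the high nibble, a packed FMA with 1/16) recovers the signed integer with the -8 zero point already fused in, without any integer-to-float conversion instructions. A host-only illustration of that step (the tiny decoder below is a simplified stand-in, not CK code, and handles only the normal half values used here):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal (non-subnormal) IEEE fp16 bit pattern to float.
static float half_bits_to_float(std::uint16_t h)
{
    const int sign = (h >> 15) & 0x1;
    const int exp  = (h >> 10) & 0x1f;
    const int frac = h & 0x3ff;
    const float v  = std::ldexp(1.0f + frac / 1024.0f, exp - 15);
    return sign ? -v : v;
}

int main()
{
    for(unsigned n = 0; n < 16; ++n)
    {
        // The EX constant with the nibble OR-ed into the mantissa.
        const std::uint16_t bits = static_cast<std::uint16_t>(0x6400u | n);
        const float as_half      = half_bits_to_float(bits); // == 1024 + n
        std::printf("nibble %2u -> half %6.1f -> signed int4 %3d\n",
                    n, as_half, static_cast<int>(as_half - 1024.0f) - 8);
    }
    return 0;
}
```

The inline asm kept under `#if 0` expresses the same computation directly as v_and_or_b32 / v_pk_add_f16 / v_pk_fma_f16 and is currently disabled.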
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
@@ -48,12 +48,7 @@ __global__ void
    //                  karg);
    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))

@@ -1303,11 +1298,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        // B Scale grid and buffer
        const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
            make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
-                      math::integer_divide_ceil(problem.K, ScaleBlockK),
-                      ScaleBlockK),
-           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1, 0));
+                      math::integer_divide_ceil(problem.K, ScaleBlockK)),
+           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));

        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
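The B-scale grid descriptor above drops its third (ScaleBlockK) dimension: scales are now addressed as a plain [ceil(N / ScaleBlockN)][ceil(K / ScaleBlockK)] array with the K-scale index contiguous. A minimal sketch of the resulting indexing (hypothetical helper, not the CK descriptor API):

```cpp
// One scale value per (ScaleBlockN x ScaleBlockK) group of B, laid out with
// stride ceil(K / ScaleBlockK) along the N-scale dimension and stride 1 along
// the K-scale dimension.
struct BScaleDesc2D
{
    int n_groups; // ceil(N / ScaleBlockN)
    int k_groups; // ceil(K / ScaleBlockK)

    int element_space_size() const { return n_groups * k_groups; }
    int offset(int n_group, int k_group) const { return n_group * k_groups + k_group; }
};

inline BScaleDesc2D make_b_scale_desc_bn_ak(int N, int K, int ScaleBlockN, int ScaleBlockK)
{
    return BScaleDesc2D{(N + ScaleBlockN - 1) / ScaleBlockN,
                        (K + ScaleBlockK - 1) / ScaleBlockK};
}
```

The per-thread copy steps defined in the next hunk (the three make_multi_index entries) then move through this 2-D view directly instead of carrying a dummy in-group coordinate.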
@@ -1438,12 +1430,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                   KPerBlock);

        // b scale
        static_assert(KPerBlock <= ScaleBlockK);
        const index_t ScaleSliceSizeN = NXdlPerWave;
        const index_t ScaleSliceSizeK = 1;

        constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{}));
+           make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

@@ -1455,22 +1447,24 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                   BScaleType,
                                                   decltype(b_scale_grid_desc_bn_ak),
                                                   decltype(b_scale_thread_desc),
-                                                  Sequence<1, ScaleSliceSizeK, 1>,
-                                                  Sequence<0, 1, 2>,
+                                                  Sequence<1, ScaleSliceSizeK>,
+                                                  Sequence<0, 1>,
                                                   1,
                                                   1,
                                                   1,
                                                   false>(
            b_scale_grid_desc_bn_ak,
-           make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0));
+           make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

        constexpr auto b_scale_thread_slice_copy_step =
-           make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0),
-                      make_multi_index(-NPerBlock, KPerBlock / ScaleBlockK, KPerBlock % ScaleBlockK));
+           make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                      make_multi_index(-NPerBlock, 0),
+                      make_multi_index(-NPerBlock, 1));

+       const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,

@@ -1867,7 +1861,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(

@@ -1875,7 +1869,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);

@@ -1924,7 +1918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
+       const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp
@@ -40,8 +40,8 @@ template <typename ADataType,
          typename BLayout,
          typename CLayout,
          index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmV2BScale<ALayout,
                                                     BLayout,
                                                     CLayout,
                                                     ADataType,
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
@@ -8,13 +8,22 @@ namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
-   std::vector<std::unique_ptr<DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128,
-                                                  PassThrough, PassThrough, PassThrough>>>& instances)
+   std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
+                                                  Col,
+                                                  Row,
+                                                  F16,
+                                                  I4,
+                                                  F16,
+                                                  F16,
+                                                  1,
+                                                  128,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
-       //device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
+       // device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
        device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
@@ -72,7 +72,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-   Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
+   Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor(
+       (K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
        N,                                   // N direction group size is 1
        Scale_Stride_BN,
        BLayout{}));

@@ -167,8 +168,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
                i4  = i4 - 8;
                v_b = ck::type_convert<float>(i4);
                b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
                                      ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
            }
        }
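The verification path above dequantizes B on the host by shifting each packed int4 value into the signed range (subtract 8) and multiplying by the group scale at (k / ScaleBlockK, n). A simplified reference loop over raw row-major vectors (an illustrative stand-in for the Tensor-based code in the hunk):

```cpp
#include <cstdint>
#include <vector>

// b_int4        : one unpacked int4 value (0..15) per (k, n) entry
// b1_k_n        : ceil(K / ScaleBlockK) x N group scales, N-direction group size 1
// b_k_n_dequant : dequantized reference weights
void dequant_b_reference(std::vector<float>& b_k_n_dequant,
                         const std::vector<std::uint8_t>& b_int4,
                         const std::vector<float>& b1_k_n,
                         int K, int N, int ScaleBlockK)
{
    for(int k = 0; k < K; ++k)
    {
        for(int n = 0; n < N; ++n)
        {
            const int   i4  = static_cast<int>(b_int4[k * N + n]) - 8; // symmetric zero point
            const float v_b = static_cast<float>(i4);
            b_k_n_dequant[k * N + n] = v_b * b1_k_n[(k / ScaleBlockK) * N + n];
        }
    }
}
```

This mirrors what the GPU path computes in registers: pki4_to_half4 supplies the (value - 8) part and the blockwise pipelines apply b1_k_n once per ScaleBlockK x 1 group.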
@@ -291,8 +291,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
        {
            auto kbatch_curr = kbatch_list[i];

            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                M,
profiler/src/profile_gemm_b_scale.cpp
@@ -82,7 +82,14 @@ int profile_gemm_b_scale(int argc, char* argv[])
        const int StrideB = std::stoi(argv[13]);
        const int StrideC = std::stoi(argv[14]);
        const int KBatch  = std::stoi(argv[15]);

-       printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n", M, N, K, StrideA, StrideB, StrideC, KBatch);
+       printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
+              M,
+              N,
+              K,
+              StrideA,
+              StrideB,
+              StrideC,
+              KBatch);

        int n_warmup = 1;
        int n_iter   = 10;

@@ -156,14 +163,18 @@ int profile_gemm_b_scale(int argc, char* argv[])
        return pass ? 0 : 1;
    };

-   // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_64)
+   // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
+   // B_scale_block == BScaleBlockTile::K_64)
    // {
-   //     return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{}, Row{});
+   //     return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{},
+   //     Row{});
    // }

    if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
       B_scale_block == BScaleBlockTile::K_128)
    {
        printf("F16_I4_F16 MK_NK_MN K_128\n");
        return profile(
            F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
    }
    else
    {