Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
837d9056
"model/models/vscode:/vscode.git/clone" did not exist on "ec9eb28f4c3481d58c6da38ee488cb8cd5379256"
Commit
837d9056
authored
Feb 17, 2025
by
mtgu0705
Browse files
Added b preshuffle pipeline v3 support.
parent
e53e764d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
9 deletions
+96
-9
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
+21
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
+75
-7
No files found.
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
View file @
837d9056
...
@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
static
constexpr
bool
PermuteA
=
false
;
static
constexpr
bool
PermuteA
=
false
;
static
constexpr
bool
PermuteB
=
false
;
static
constexpr
bool
PermuteB
=
false
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
;
// clang-format off
// clang-format off
#if 0
using DeviceGemmV2Instance =
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
ALayout, BLayout, CLayout,
...
@@ -38,7 +38,7 @@ using DeviceGemmV2Instance =
...
@@ -38,7 +38,7 @@ using DeviceGemmV2Instance =
AElementOp, BElementOp, CElementOp, GemmDefault,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
256,
128, 128,
128, 128,
KPerBlock
,
16
,
32
,
256
, 16, 32,
32, 32,
32, 32,
4, 1,
4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
...
@@ -47,7 +47,26 @@ using DeviceGemmV2Instance =
...
@@ -47,7 +47,26 @@ using DeviceGemmV2Instance =
2, 32, 32, 0,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 4,
1, 1, S<1, 32, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
#else
using
DeviceGemmV2Instance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3_BPreshuffle
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
256
,
256
,
128
,
16
,
32
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
F8
,
F8
,
PermuteA
,
PermuteB
>
;
#endif
// clang-format on
// clang-format on
template
<
typename
ProblemType
>
template
<
typename
ProblemType
>
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
View file @
837d9056
...
@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
BDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_dequant_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
b_thread_desc_
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_dequant_buf
),
Number
<
2
>
{}
>
b_thread_dequant_bufs
;
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
);
// Global prefetch A1 B1
// Global prefetch A1 B1
...
@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple
(
I0
,
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
I0
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
I0
));
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
...
@@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0
,
I0
,
ik
))
>
{}];
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
]
b_thread_
dequant_
bufs
[
mfma_reg_buf
]
[
Number
<
b_thread_desc_
.
CalculateOffset
(
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0
),
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
local_read_buf
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
local_read_buf
));
}
}
else
else
{
{
...
@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0
),
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
mfma_reg_buf
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
mfma_reg_buf
));
}
}
HotLoopScheduler
(
m0
);
HotLoopScheduler
(
m0
);
...
@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
make_tuple
(
m0
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_
dequant_
bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
I1
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
I1
));
}
}
else
else
{
{
...
@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
I0
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
I0
));
}
}
EpilogueScheduler_1
(
m0
);
EpilogueScheduler_1
(
m0
);
...
@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
(
m0
+
HotloopLocalBufSwitch
)
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
(
m0
+
HotloopLocalBufSwitch
)
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I1
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_
dequant_
bufs
[
I1
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
Number
<
(
m0
+
1
+
HotloopLocalBufSwitch
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
Number
<
(
m0
+
1
+
HotloopLocalBufSwitch
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
I1
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
I1
));
EpilogueScheduler_2
();
EpilogueScheduler_2
();
}
}
...
@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
make_tuple
(
m0
%
2
,
I0
,
I0
,
k0
,
I0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
b_thread_
dequant_
bufs
[
I0
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
...
@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
make_tuple
(
Number
<
(
m0
+
1
)
%
2
>
{},
I0
,
I0
,
k0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_
.
Run
(
b_block_desc_n0_n1_k0_k1
,
b_block_origin_idx
,
b_thread_bufs
(
I0
),
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_dequant_bufs
(
I0
));
EpilogueScheduler_2
();
EpilogueScheduler_2
();
}
}
...
@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
...
@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static
constexpr
BTileDesc
b_block_desc_n0_n1_k0_k1
;
static
constexpr
BTileDesc
b_block_desc_n0_n1_k0_k1
;
using
Base
::
c_thread_desc_
;
using
Base
::
c_thread_desc_
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BThreadDequantCopy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_k0_k1
),
decltype
(
b_block_desc_n0_n1_k0_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}
>
,
Sequence
<
1
,
2
,
0
,
3
>
,
3
,
KPack
>
;
const
PassThrough
b_element_op
{};
BThreadDequantCopy
b_thread_dequant_copy_
{
b_element_op
};
};
};
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment