Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e0391df7
Commit
e0391df7
authored
Feb 11, 2025
by
mtgu0705
Browse files
Added gemm_fp8xint4_Bpreshuffle files, function not checked yet
parent
2559ef64
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2437 additions
and
16 deletions
+2437
-16
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
+8
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
.../blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
+2
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
+1
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+33
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
.../device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
+514
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
...n/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
+1877
-0
No files found.
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3
_b_preshuffle
.hpp"
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
I4
=
ck
::
pk_i4_t
;
using
I4
=
ck
::
pk_i4_t
;
...
@@ -63,7 +63,7 @@ static constexpr ck::index_t KPerBlock = 128;
...
@@ -63,7 +63,7 @@ static constexpr ck::index_t KPerBlock = 128;
// clang-format off
// clang-format off
using
DeviceGemmV2Instance
=
using
DeviceGemmV2Instance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3
<
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3
_BPreshuffle
<
ALayout
,
BLayout
,
CLayout
,
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
...
@@ -77,7 +77,7 @@ using DeviceGemmV2Instance =
...
@@ -77,7 +77,7 @@ using DeviceGemmV2Instance =
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
ADataType
,
ADataType
,
PermuteA
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
F8
,
F8
,
PermuteA
,
PermuteB
>
;
// clang-format on
// clang-format on
...
@@ -174,8 +174,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -174,8 +174,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
//
int NperXdl=
16
;
int
NperXdl
=
GetPreShuffleParameters
;
//
preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);
preShuffleBuffer
(
b_k_n
.
mData
.
data
(),
b_k_n_preshuffled
.
mData
.
data
(),
N
,
K
,
NperXdl
);
// weight permute
// weight permute
if
constexpr
(
PermuteB
)
if
constexpr
(
PermuteB
)
...
@@ -190,7 +190,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -190,7 +190,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
{
for
(
int
jj
=
0
;
jj
<
K1
;
jj
++
)
for
(
int
jj
=
0
;
jj
<
K1
;
jj
++
)
{
{
b_k_n_permute
(
j
*
N
*
K1
+
i
*
K1
+
jj
)
=
b_k_n
(
i
*
K
+
(
j
*
K1
+
jj
));
b_k_n_permute
(
j
*
N
*
K1
+
i
*
K1
+
jj
)
=
b_k_n
_preshuffled
(
i
*
K
+
(
j
*
K1
+
jj
));
}
}
}
}
}
}
...
@@ -201,7 +201,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -201,7 +201,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
{
for
(
int
j
=
0
;
j
<
K
;
j
++
)
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
{
b_k_n_permute
(
i
*
K
+
j
)
=
b_k_n
(
i
*
K
+
j
);
b_k_n_permute
(
i
*
K
+
j
)
=
b_k_n
_preshuffled
(
i
*
K
+
j
);
}
}
}
}
}
}
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
namespace
ck
{
namespace
ck
{
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -305,9 +305,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
...
@@ -305,9 +305,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
local_read_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
local_read_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
// type_convert<float>(a_thread_buf[I0]),
// type_convert<float>(b_thread_bufs[mfma_reg_buf][I0]));
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
e0391df7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -114,6 +114,38 @@ struct DeviceGemmV2BScale : public BaseOperator
...
@@ -114,6 +114,38 @@ struct DeviceGemmV2BScale : public BaseOperator
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
};
};
/// @brief Abstract device-level GEMM operator interface (B-preshuffle variant).
///
/// Derives from BaseOperator and declares only pure-virtual members, so it
/// cannot be instantiated directly; a concrete operator must override all four
/// functions below.
///
/// Of the template parameters, only the three element-wise operation types
/// appear in the declarations of this struct. The layout and data type
/// parameters are not referenced here.
/// NOTE(review): presumably they exist so implementations can be selected by
/// layout/data type through the base type — confirm against the concrete
/// operator.
///
/// @tparam ALayout, BLayout, CLayout        Not referenced in this struct.
/// @tparam ADataType, BDataType, CDataType  Not referenced in this struct.
/// @tparam AElementwiseOperation  Type of the `a_element_op` argument.
/// @tparam BElementwiseOperation  Type of the `b_element_op` argument.
/// @tparam CElementwiseOperation  Type of the `c_element_op` argument.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmV2BPreshuffle : public BaseOperator
{
    /// @brief Builds the argument object for one problem.
    ///
    /// The buffers are passed as untyped pointers: `p_a` and `p_b` are
    /// read-only (`const void*`), `p_c` is writable (`void*`).
    ///
    /// @param p_a, p_b      Untyped read-only buffer pointers.
    /// @param p_c           Untyped writable buffer pointer.
    /// @param M, N, K       NOTE(review): presumably the GEMM problem sizes — confirm.
    /// @param StrideA, StrideB, StrideC
    ///                      NOTE(review): presumably the strides of A, B and C — confirm.
    /// @param KSplit        NOTE(review): presumably a split-K factor — confirm.
    /// @param a_element_op, b_element_op, c_element_op
    ///                      Element-wise operation objects, taken by value.
    /// @return Owning pointer to the argument, typed as BaseArgument.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        ck::index_t KSplit,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    /// @brief Builds the invoker object.
    /// @return Owning pointer to the invoker, typed as BaseInvoker.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    /// @brief Reports the implementation's PermuteB setting.
    /// NOTE(review): the exact meaning of the flag is defined by the
    /// implementation, not by this interface — confirm there.
    virtual bool GetPermuteB() = 0;

    /// @brief Reports the implementation's KPerBlock value.
    /// NOTE(review): presumably the K tile size per block — confirm.
    virtual ck::index_t GetKPerBlock() = 0;
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
0 → 100644
View file @
e0391df7
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
0 → 100644
View file @
e0391df7
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment