Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6d0e78bd
"docs/vscode:/vscode.git/clone" did not exist on "5710567ce3dc1f9aab32a910d79d06a88f95f56d"
Commit
6d0e78bd
authored
Oct 22, 2024
by
Jing Zhang
Browse files
improve weight layout
parent
5d42067e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
37 deletions
+22
-37
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
+12
-27
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+4
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+6
-9
No files found.
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
View file @
6d0e78bd
...
@@ -40,8 +40,7 @@ using DeviceGemmV2Instance =
...
@@ -40,8 +40,7 @@ using DeviceGemmV2Instance =
1, 1, S<1, 16, 1, 4>, 4,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
static int NPerBlock = 16;
[[maybe_unused]] static int KPerBlock = 256;
static int KPerBlock = 256;
#else
#else
128
,
128
,
16
,
32
,
16
,
32
,
...
@@ -53,10 +52,9 @@ using DeviceGemmV2Instance =
...
@@ -53,10 +52,9 @@ using DeviceGemmV2Instance =
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v
1
>
;
ck
::
BlockGemmPipelineScheduler
::
Interwave
,
ck
::
BlockGemmPipelineVersion
::
v
2
>
;
static
int
NPerBlock
=
32
;
[[
maybe_unused
]]
static
int
KPerBlock
=
128
;
static
int
KPerBlock
=
128
;
#endif
#endif
// clang-format on
// clang-format on
...
@@ -125,7 +123,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -125,7 +123,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
break
;
break
;
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_
2
<
ADataType
>
{
-
2
,
2
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_
3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
...
@@ -153,31 +151,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -153,31 +151,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// weight permute
// weight permute
#if 0
#if 1
int N1 = NPerBlock;
int
K1
=
KPerBlock
;
int
K1
=
KPerBlock
;
int
K0
=
K
/
KPerBlock
;
int N0 = N / N1;
// int K0, N, K1
int K0 = K / K1;
for
(
int
j
=
0
;
j
<
K0
;
j
++
)
int K01 = K0 / KBatch;
int K00 = KBatch;
std::cout << "K00 = " << K00 << " K01 = " << K01 << std::endl;
for(int k = 0; k < K00; k++)
{
{
for(int i = 0; i < N
0
; i++)
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
{
for(int j = 0; j < K
0
1; j++)
for
(
int
j
j
=
0
;
j
j
<
K1
;
j
j
++
)
{
{
for(int ii = 0; ii < N1; ii++)
b_k_n_permute
(
j
*
N
*
K1
+
i
*
K1
+
jj
)
=
b_k_n
(
i
*
K
+
(
j
*
K1
+
jj
));
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(k * N0 * K01 * N1 * K1 + i * K01 * N1 * K1 + j * N1 * K1 + ii * K1 + jj) =
b_k_n((i * N1 + ii) * K + (k * K01 * K1 + j * K1 + jj));
}
}
}
}
}
}
}
}
...
@@ -286,7 +271,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -286,7 +271,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
});
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
6d0e78bd
...
@@ -50,7 +50,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
...
@@ -50,7 +50,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
return {h_f16, l_f16};
return {h_f16, l_f16};
#el
se
#el
if
1
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
q
);
int
x_l
=
(
x_u8
&
0x0f
);
int
x_l
=
(
x_u8
&
0x0f
);
...
@@ -62,6 +62,9 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
...
@@ -62,6 +62,9 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
int
lo
=
(
x_l
|
x_h
)
|
EX
;
int
lo
=
(
x_l
|
x_h
)
|
EX
;
return
amd_assembly_pk_add_f16
(
bit_cast
<
half2_t
>
(
lo
),
bit_cast
<
half2_t
>
(
SUB
));
return
amd_assembly_pk_add_f16
(
bit_cast
<
half2_t
>
(
lo
),
bit_cast
<
half2_t
>
(
SUB
));
#else
int32_t
res
=
bit_cast
<
int8_t
>
(
q
);
return
bit_cast
<
half2_t
>
(
res
);
#endif
#endif
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
6d0e78bd
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
//
#define WEIGHT_PERMUTE
#define WEIGHT_PERMUTE
namespace
ck
{
namespace
ck
{
...
@@ -399,22 +399,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -399,22 +399,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
#else
#else
constexpr
index_t
N1
=
NPerBlock
;
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
const
index_t
BK00
=
BK0
/
BK01
;
const
index_t
BK00
=
BK0
/
BK01
;
const
index_t
N0
=
N
/
N1
;
const
auto
b_grid_desc_
n0_
bk00_n
1
_bk01_bk1
=
const
auto
b_grid_desc_bk00_n_bk01_bk1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N0
,
BK00
,
N
1
,
BK01
,
BK1Value
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
BK00
,
N
,
BK01
,
BK1Value
));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_
n0_
bk00_n
1
_bk01_bk1
,
b_grid_desc_bk00_n_bk01_bk1
,
make_tuple
(
make_merge_transform
(
make_tuple
(
BK00
,
BK01
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
BK00
,
BK01
)),
make_
merge
_transform
(
make_tuple
(
N
0
,
N1
)),
make_
pass_through
_transform
(
make_tuple
(
N
)),
make_pass_through_transform
(
BK1Value
)),
make_pass_through_transform
(
BK1Value
)),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#endif
#endif
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment