Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
43fd4742
Commit
43fd4742
authored
Jan 09, 2025
by
mtgu0705
Browse files
Created a branch int4_pr_based_on_JingPR
parent
1d8e4ec2
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
6752 additions
and
0 deletions
+6752
-0
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
+425
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
...block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
+167
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
+403
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
+1248
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
+530
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
+686
-0
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+35
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
+780
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+65
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
+2208
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+204
-0
No files found.
example/01_gemm/CMakeLists.txt
View file @
43fd4742
...
...
@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_v3
)
add_example_executable
(
example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
)
add_example_executable
(
example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8_v3
)
add_example_executable
(
example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp
)
...
...
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
0 → 100644
View file @
43fd4742
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
pk_i4_t
;
using
BScaleDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
bool
PermuteB
=
true
;
static
constexpr
ck
::
index_t
Scale_Block_N
=
1
;
static
constexpr
ck
::
index_t
Scale_Block_K
=
128
;
static
constexpr
ck
::
index_t
KPerBlock
=
64
;
// clang-format off
using
DeviceGemmV2Instance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
BScaleDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
Scale_Block_N
,
Scale_Block_K
,
128
,
128
,
KPerBlock
,
8
,
32
,
32
,
32
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
AccDataType
,
CDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
template
<
typename
ProblemType
>
bool
run_gemm
(
const
ProblemType
&
problem_size
,
const
ExecutionConfig
&
config
)
{
using
namespace
ck
::
literals
;
auto
M
=
problem_size
.
M
;
auto
N
=
problem_size
.
N
;
auto
K
=
problem_size
.
K
;
auto
StrideA
=
problem_size
.
StrideA
;
auto
StrideB
=
problem_size
.
StrideB
;
auto
StrideC
=
problem_size
.
StrideC
;
auto
KBatch
=
problem_size
.
KBatch
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
if
(
stride
==
-
1
)
{
// give a chance if stride is -1, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
static_cast
<
std
::
size_t
>
(
col
);
}
else
{
return
static_cast
<
std
::
size_t
>
(
row
);
}
}
else
return
static_cast
<
std
::
size_t
>
(
stride
);
};
ck
::
index_t
Scale_Stride_BN
=
(
K
+
Scale_Block_K
-
1
)
/
Scale_Block_K
;
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideC
=
f_get_default_stride
(
M
,
N
,
StrideC
,
CLayout
{});
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n_permute
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BScaleDataType
>
b1_k_n
(
f_host_tensor_descriptor
((
K
+
Scale_Block_K
-
1
)
/
Scale_Block_K
,
(
N
+
Scale_Block_N
-
1
)
/
Scale_Block_N
,
Scale_Stride_BN
,
BLayout
{}));
switch
(
config
.
init_method
)
{
case
0
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BScaleDataType
>
{
1
});
break
;
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BScaleDataType
>
{
0
,
1.0
});
break
;
case
2
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BScaleDataType
>
{
1
});
break
;
case
3
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BScaleDataType
>
{
1
});
break
;
case
4
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BScaleDataType
>
{
0
,
1.0
});
break
;
case
5
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BScaleDataType
>
{
1
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.5
,
0.5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BScaleDataType
>
{
0
,
1.0
});
}
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_k_n: "
<<
b1_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b1_scale_device_buf
(
sizeof
(
BScaleDataType
)
*
b1_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
printf
(
"b_k_n element space size: %zu, b_k_n device size: %lu, BDataType size: %lu
\n
"
,
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
(),
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
(),
sizeof
(
BDataType
));
// weight permute
if
constexpr
(
PermuteB
)
{
int
K1
=
KPerBlock
;
int
K0
=
K
/
KPerBlock
;
// int K0, N, K1
for
(
int
j
=
0
;
j
<
K0
;
j
++
)
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
jj
=
0
;
jj
<
K1
;
jj
++
)
{
b_k_n_permute
(
j
*
N
*
K1
+
i
*
K1
+
jj
)
=
b_k_n
(
i
*
K
+
(
j
*
K1
+
jj
));
}
}
}
}
else
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
b_k_n_permute
(
i
*
K
+
j
)
=
b_k_n
(
i
*
K
+
j
);
}
}
}
// vector pk_i4x4 permute
#if 1
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
j
=
0
;
j
<
K
;
j
+=
8
)
{
int
input
[
8
];
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
).
data
;
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
// permute 01234567->20643175
{
int
hi
=
input
[
2
];
int
lo
=
input
[
0
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
0
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
6
];
int
lo
=
input
[
4
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
2
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
3
];
int
lo
=
input
[
1
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
4
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
7
];
int
lo
=
input
[
5
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
6
,
i
)
=
i4x2
;
}
}
}
#endif
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
b1_scale_device_buf
.
ToDevice
(
b1_k_n
.
mData
.
data
());
DeviceMem
workspace
;
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmV2Instance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
float
ave_time
=
0
;
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
Scale_Stride_BN
,
static_cast
<
BScaleDataType
*>
(
b1_scale_device_buf
.
GetDeviceBuffer
()),
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
std
::
size_t
workspace_size
=
gemm
.
GetWorkSpaceSize
(
&
argument
);
printf
(
"workspace_size: %zu
\n
"
,
workspace_size
);
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
Tensor
<
float
>
b_k_n_dequant
({
K
,
N
});
float
v_b
=
0
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
ck
::
pk_i4_t
i4x2
=
b_k_n
(
k
,
n
);
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
.
data
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
.
data
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
v_b
=
ck
::
type_convert
<
float
>
(
i4
);
b_k_n_dequant
(
k
,
n
)
=
ck
::
type_convert
<
float
>
(
v_b
)
*
ck
::
type_convert
<
float
>
(
b1_k_n
(
k
/
Scale_Block_K
,
n
/
Scale_Block_N
));
}
}
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n_dequant
,
c_m_n_host_result
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
});
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout<<"scale_b1_k_n: "<<std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < (K + Scale_Block_K - 1) / Scale_Block_K; j++)
{
std::cout << ck::type_convert<float>(b1_k_n(j,i)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
}
if
(
config
.
time_kernel
)
{
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
0
,
20
,
50
,
true
,
50
});
std
::
size_t
flop
=
2
_uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
/
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
BDataType
>
,
ck
::
pk_i4_t
>
?
2
:
1
)
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
return
pass
;
}
bool
run_gemm_splitk_example
(
int
argc
,
char
*
argv
[])
{
ProblemSizeSplitK
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_splitk_example
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
0 → 100644
View file @
43fd4742
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace
ck
{
enum
struct
BlockGemmPipelineVersion
{
v1
,
// Naive
v2
,
// Mem
v3
,
// Comp
v4
,
// Comp, double lds buffer
v5
,
// Comp, double global prefetch register buffer
};
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
BlockGemmPipelineScheduler
BlkGemmPipeSche
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
constexpr
auto
BlockGemmPipeline_Selector
()
{
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
return
BlockwiseGemmXdlops_pipeline_v1_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
return
BlockwiseGemmXdlops_pipeline_v2_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
return
BlockwiseGemmXdlops_pipeline_v3_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
return
BlockwiseGemmXdlops_pipeline_v4_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v5
)
{
return
BlockwiseGemmXdlops_pipeline_v5
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
{
std
::
cerr
<<
"BlockGemmPipeline configuration is not available"
<<
std
::
endl
;
}
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
0 → 100644
View file @
43fd4742
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v1_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v1_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
// BScale Thread Copy
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
// ABlockCopy
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
// BBlockCopy
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
// CThread
CThreadBuffer
&
c_thread_buf
,
// BScaleThreadCopy
const
BScaleGridDesc
&
b_scale_grid_desc
,
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num_loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
// assume kperblock = scaleblockk
ignore
=
num_loop_per_scale
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
auto
c_thread_buf_per_scale
=
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
c_thread_buf_per_scale
.
Clear
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
t
));
c_thread_buf
(
Number
<
c_offset
>
{})
+=
c_thread_buf_per_scale
[
Number
<
t
>
{}]
*
type_convert
<
AccDataType
>
(
b_scale_thread_buf
[
n0
]);
});
});
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
c_thread_buf_per_scale
.
Clear
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
t
));
c_thread_buf
(
Number
<
c_offset
>
{})
+=
c_thread_buf_per_scale
[
Number
<
t
>
{}]
*
type_convert
<
AccDataType
>
(
b_scale_thread_buf
[
n0
]);
});
});
});
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
0 → 100644
View file @
43fd4742
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
0 → 100644
View file @
43fd4742
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
0 → 100644
View file @
43fd4742
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
43fd4742
...
...
@@ -77,6 +77,41 @@ struct DeviceGemmV2R1 : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
BScaleType
,
typename
CDataType
,
index_t
ScaleBlockN
,
index_t
ScaleBlockK
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmV2BScale
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideScaleB
,
const
void
*
p_b_scale
,
ck
::
index_t
KSplit
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
bool
GetPermuteB
()
=
0
;
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
0 → 100644
View file @
43fd4742
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
43fd4742
...
...
@@ -44,6 +44,40 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
return
res
.
template
AsType
<
half4_t
>()[
Number
<
0
>
{}];
}
__host__
__device__
inline
half4_t
pki4_to_half4_scale
(
int
q
,
const
ck
::
half2_t
&
scale
)
{
constexpr
int
LO
=
0x000f000f
;
constexpr
int
HI
=
0x00f000f0
;
constexpr
int
EX
=
0x64006400
;
// Extract the two int4 at low bit and create two fp16 number.
int
lo
=
amd_assembly_and_or_b32
(
q
,
LO
,
EX
);
// Extract the two int4 at hight bit and create two fp16 number.
int
hi
=
amd_assembly_and_or_b32
(
q
,
HI
,
EX
);
const
int
SUB
=
0xE408E408
;
// half2 {-1032, -1032}
const
int
MUL
=
0x2c002c00
;
// half2 {1 / 16, 1 / 16}
const
int
ADD
=
0xd480d480
;
// half2 {-72, -72}
vector_type
<
half_t
,
4
>
res
;
// for two fp16 from lowbit, subtract 1032 to get correct fp16 value
res
.
template
AsType
<
half2_t
>()(
Number
<
0
>
{})
=
amd_assembly_pk_add_f16
(
bit_cast
<
half2_t
>
(
lo
),
bit_cast
<
half2_t
>
(
SUB
));
// for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
res
.
template
AsType
<
half2_t
>()(
Number
<
1
>
{})
=
amd_assembly_pk_fma_f16
(
bit_cast
<
half2_t
>
(
hi
),
bit_cast
<
half2_t
>
(
MUL
),
bit_cast
<
half2_t
>
(
ADD
));
asm
volatile
(
"v_pk_mul_f16 %0, %1, %2"
:
"=v"
(
res
.
template
AsType
<
half2_t
>()(
Number
<
0
>
{}))
:
"v"
(
res
.
template
AsType
<
half2_t
>()(
Number
<
0
>
{})),
"v"
(
scale
));
asm
volatile
(
"v_pk_mul_f16 %0, %1, %2"
:
"=v"
(
res
.
template
AsType
<
half2_t
>()(
Number
<
1
>
{}))
:
"v"
(
res
.
template
AsType
<
half2_t
>()(
Number
<
1
>
{})),
"v"
(
scale
));
return
res
.
template
AsType
<
half4_t
>()[
Number
<
0
>
{}];
}
__host__
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
{
#if 1
...
...
@@ -178,6 +212,37 @@ struct PassThroughPack8
constexpr
const
static
bool
is_pack8_invocable
=
true
;
};
struct
DequantPack8
{
template
<
typename
Y
,
typename
X
,
typename
Z
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
,
const
Z
&
z
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
half8_t
&
y
,
const
ck
::
pk_i4x4_t
&
x
,
const
ck
::
half2_t
&
z
)
const
{
#if 1
vector_type
<
half_t
,
8
>
result
;
result
.
template
AsType
<
half4_t
>()(
Number
<
0
>
{})
=
pki4_to_half4_scale
(
bit_cast
<
int
>
(
x
),
z
);
result
.
template
AsType
<
half4_t
>()(
Number
<
1
>
{})
=
pki4_to_half4_scale
(
bit_cast
<
int
>
(
x
)
>>
8
,
z
);
y
=
result
.
template
AsType
<
half8_t
>()[
Number
<
0
>
{}];
#else
vector_type
<
half_t
,
8
>
dst
;
vector_type
<
pk_i4_t
,
4
>
src
{
x
};
dst
.
template
AsType
<
half2_t
>()(
Number
<
0
>
{})
=
pki4_to_half2
(
src
.
template
AsType
<
pk_i4_t
>()[
Number
<
0
>
{}]);
dst
.
template
AsType
<
half2_t
>()(
Number
<
1
>
{})
=
pki4_to_half2
(
src
.
template
AsType
<
pk_i4_t
>()[
Number
<
1
>
{}]);
dst
.
template
AsType
<
half2_t
>()(
Number
<
2
>
{})
=
pki4_to_half2
(
src
.
template
AsType
<
pk_i4_t
>()[
Number
<
2
>
{}]);
dst
.
template
AsType
<
half2_t
>()(
Number
<
3
>
{})
=
pki4_to_half2
(
src
.
template
AsType
<
pk_i4_t
>()[
Number
<
3
>
{}]);
y
=
dst
.
template
AsType
<
half8_t
>()[
Number
<
0
>
{}];
#endif
}
constexpr
const
static
bool
is_pack8_invocable
=
true
;
};
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct
UnaryOpBase
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
0 → 100644
View file @
43fd4742
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
43fd4742
...
...
@@ -1222,6 +1222,210 @@ struct ThreadwiseTensorSliceTransfer_v4
});
}
// Fuse scale
template
<
typename
SrcRefToOriginDisplacement
,
typename
DstOriginIdx
,
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcRefToOriginDisplacement
&
,
const
SrcBuffer
&
src_buf
,
const
DstData
&
scale
,
const
DstDesc
&
,
const
DstOriginIdx
&
,
DstBuffer
&
dst_buf
)
const
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cvref_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cvref_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
constexpr
auto
dst_origin_idx
=
to_multi_index
(
DstOriginIdx
{});
// scalar per access of each dim
constexpr
auto
src_scalar_per_access
=
generate_sequence_v2
(
[
&
](
auto
i
)
constexpr
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Number
<
SrcScalarPerVector
>
{};
}
else
{
return
Number
<
1
>
{};
}
},
Number
<
nDim
>
{});
// scalar step (if steping on SrcVectorDim) of each dim
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence_v2
(
[
&
](
auto
i
)
constexpr
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Number
<
1
>
{};
}
else
{
return
Number
<
0
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
dim_access_order
=
DimAccessOrder
{};
constexpr
auto
ordered_access_lengths
=
container_reorder_given_new2old
(
access_lengths
,
dim_access_order
);
static_ford
<
decltype
(
ordered_access_lengths
)
>
{}([
&
](
auto
ordered_access_idx
)
{
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr
auto
data_to_origin_disp_idx
=
ordered_access_idx
.
ReorderGivenOld2New
(
dim_access_order
)
*
src_scalar_per_access
;
#endif
// src coordinate
constexpr
auto
src_ref_to_data_disp_idx
=
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
;
constexpr
auto
src_ref_to_data_disp_coord_step
=
make_tensor_coordinate_step
(
src_desc
,
src_ref_to_data_disp_idx
);
auto
src_data_coord
=
src_ref_coord_
;
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_step
);
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
/
PackedSize
>
src_tmp_vector
;
using
src_vector_t
=
typename
decltype
(
src_tmp_vector
)
::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_data_coord
);
// copy data from src_buf into src_tmp_vector
if
constexpr
(
SrcBuffer
::
IsDynamicBuffer
())
{
src_tmp_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_buf
.
template
Get
<
src_vector_t
>(
src_data_coord
.
GetOffset
()
/
PackedSize
,
is_src_valid
);
}
else
if
constexpr
(
SrcBuffer
::
IsStaticBuffer
())
{
static_assert
(
false
,
""
);
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
});
}
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
vector_type
<
DstData
,
2
>
scale_vector
;
scale_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
scale
;
scale_vector
.
template
AsType
<
DstData
>()(
Number
<
1
>
{})
=
scale
;
constexpr
index_t
pack_size
=
8
;
static_assert
(
SrcScalarPerVector
%
pack_size
==
0
,
""
);
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
/
PackedSize
>::
type
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
using
scale_v_t
=
typename
vector_type_maker_t
<
DstData
,
2
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
DequantPack8
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
],
scale_vector
.
template
AsType
<
scale_v_t
>()[
Number
<
0
>
{}]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
f8_t
>::
value
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
constexpr
index_t
pack_size
=
8
;
static_assert
(
SrcScalarPerVector
%
pack_size
==
0
,
""
);
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
/
PackedSize
>::
type
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThroughPack8
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
f8_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
&&
SrcScalarPerVector
%
2
==
0
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
constexpr
index_t
pack_size
=
2
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThroughPack2
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
});
}
template
<
typename
SrcSliceMoveStepIdx
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
,
const
SrcSliceMoveStepIdx
&
src_slice_move_step_idx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment