Project: gaoqiong / composable_kernel_ROCM

Unverified commit 5936aad9, authored Dec 06, 2024 by Mingtao Gu, committed by GitHub on Dec 06, 2024

Merge pull request #3 from ROCm/i4_update

I4 update

Parents: 5662fc11, d06be51d

Showing 10 changed files with 711 additions and 359 deletions (+711 / -359)
Changed files:

example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp  (+1 / -1)
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc  (+110 / -113)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp  (+41 / -59)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp  (+164 / -179)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp  (+3 / -3)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (+155 / -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp  (+3 / -1)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (+231 / -0)
include/ck/utility/amd_inline_asm.hpp  (+2 / -2)
script/cmake-ck-dev.sh  (+1 / -1)
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp

@@ -39,7 +39,7 @@ using DeviceGemmV2Instance =
        128, 128, KPerBlock, 8, 32, 32, 32, 2, 2, 4, 1,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc  (file mode changed: 100755 → 100644)

@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    }

    DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
    DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    // Intermediate Value For E Real and Img
    DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
    b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());

@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // set zero for intermediate values
    e_device_buf_re1.SetZero();
    e_device_buf_img1.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};

@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // device operation
    // For real Intermediate Value re_1
    auto op           = DeviceOpInstance{};
    auto invoker      = op.MakeInvoker();
    auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                                        b_device_buf_re.GetDeviceBuffer(),
                                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
                                        e_device_buf_re1.GetDeviceBuffer(),
                                        a_ms_ks_lengths,
                                        a_ms_ks_strides,
                                        b_ns_ks_lengths,
                                        b_ns_ks_strides,
                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                                        e_ms_ns_lengths,
                                        e_ms_ns_strides,
                                        a_element_op,
                                        b_element_op,
                                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re1))
    {

@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});

    alpha = -1.f;
    beta  = 1.f;

@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // For real Intermediate Value re_2
    // auto op      = DeviceOpInstance{};
    // auto invoker = op.MakeInvoker();
    auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                                        b_device_buf_img.GetDeviceBuffer(),
                                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
                                        e_device_buf_re.GetDeviceBuffer(),
                                        a_ms_ks_lengths,
                                        a_ms_ks_strides,
                                        b_ns_ks_lengths,
                                        b_ns_ks_strides,
                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                                        e_ms_ns_lengths,
                                        e_ms_ns_strides,
                                        a_element_op,
                                        b_element_op,
                                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re2))
    {

@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});

    alpha = 1.f;
    beta  = 1.f;

@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    b_element_op   = BElementOp{};
    cde_element_op = CDEElementOp{alpha, beta};

    auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                                         b_device_buf_img.GetDeviceBuffer(),
                                         std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
                                         e_device_buf_img1.GetDeviceBuffer(),
                                         a_ms_ks_lengths,
                                         a_ms_ks_strides,
                                         b_ns_ks_lengths,
                                         b_ns_ks_strides,
                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                                         e_ms_ns_lengths,
                                         e_ms_ns_strides,
                                         a_element_op,
                                         b_element_op,
                                         cde_element_op);

    if(!op.IsSupportedArgument(argument_img1))
    {

@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    alpha = 1.f;
    beta  = 1.f;

    auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                                         b_device_buf_re.GetDeviceBuffer(),
                                         std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
                                         e_device_buf_img.GetDeviceBuffer(),
                                         a_ms_ks_lengths,
                                         a_ms_ks_strides,
                                         b_ns_ks_lengths,
                                         b_ns_ks_strides,
                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                                         e_ms_ns_lengths,
                                         e_ms_ns_strides,
                                         a_element_op,
                                         b_element_op,
                                         cde_element_op);

    if(!op.IsSupportedArgument(argument_img2))
    {

@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});

    ck::index_t M = ck::accumulate_n<ck::index_t>(
        e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});

@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                            sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;

    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "

@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());

    auto isRealOk = 0;
    auto isImgOk  = 0;

    if(do_verification)
    {

@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        auto ref_op      = ReferenceOpInstance{};
        auto ref_invoker = ref_op.MakeInvoker();

        auto ref_argument_re = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re);

        alpha = 1.f;
        beta  = 1.f;

        cde_element_op = CDEElementOp{alpha, beta};
        for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)

@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        alpha = 1.f;
        beta  = -1.f;

        cde_element_op = CDEElementOp{alpha, beta};

        auto ref_argument_re1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re1);

@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

        // Img Part Verification
        Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
        Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);

        auto ref_argument_img = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img);

        alpha = 1.f;
        beta  = 1.f;

        cde_element_op = CDEElementOp{alpha, beta};
        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)

@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        auto ref_argument_img1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img1);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)

@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

        return (isRealOk && isImgOk);
    }
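The example above realizes one complex-valued bilinear contraction as four real contractions, reusing the bilinear alpha/beta element op to fold the sign of the imaginary product into the second real pass. A minimal scalar sketch of the arithmetic being chained, not part of the diff; the names are illustrative scalars rather than the example's tensors, and the first pass is assumed to run with alpha = beta = 1:

// Scalar sketch of the four-pass decomposition behind argument_re1/re2 and
// argument_img1/img2; D is the bilinear addend carried through the passes.
struct Complex
{
    float re, im;
};

Complex complex_bilinear(Complex a, Complex b, Complex d)
{
    float e_re1 = a.re * b.re + d.re;           // pass re_1:  alpha = 1 (assumed), beta = 1
    float e_re  = -1.f * (a.im * b.im) + e_re1; // pass re_2:  alpha = -1,          beta = 1
    float e_im1 = a.re * b.im + d.im;           // pass img_1: alpha = 1,           beta = 1
    float e_im  = a.im * b.re + e_im1;          // pass img_2: alpha = 1,           beta = 1
    return {e_re, e_im};                        // e = a * b + d in complex arithmetic
}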
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp

@@ -359,13 +359,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // Initialize C
        c_thread_buf.Clear();

-       StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                 AccDataType,
-                                 1,
-                                 xdlops_gemm.GetRegSizePerXdlops(),
-                                 true>
-           c_thread_buf_per_scale;
-
        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {

@@ -381,6 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
+                                  b_scale_thread_buf[n0],
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);

@@ -406,10 +400,31 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-           static_for<0, MRepeat, 1>{}([&](auto m0) {
-               static_for<0, NRepeat, 1>{}([&](auto n0) {
-                   c_thread_buf_per_scale.Clear();
-                   static_for<0, KRepeat, 1>{}([&](auto k0) {
+           static_for<0, NRepeat, 1>{}([&](auto n0) {
+               b_scale_thread_copy.Run(b_scale_grid_desc,
+                                       b_scale_grid_buf,
+                                       b_scale_thread_desc,
+                                       make_tuple(n0, I0),
+                                       b_scale_thread_buf);
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<0>{}));
+           });
+
+           if((i + 2) % num_loop_per_scale == 0)
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<2>{}));
+           }
+           else
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<1>{}));
+           }
+
+           static_for<0, KRepeat, 1>{}([&](auto k0) {
+               static_for<0, MRepeat, 1>{}([&](auto m0) {
+                   static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

@@ -426,20 +441,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

-                       // constexpr index_t c_offset =
-                       //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                       xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                       b_thread_vec.template AsType<mfma_input_type>(),
-                                       c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                   });
-                   static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                       constexpr index_t c_offset =
-                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                       c_thread_buf(Number<c_offset>{}) +=
-                           c_thread_buf_per_scale[Number<t>{}] *
-                           // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-                           type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                       constexpr index_t c_offset =
+                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                       xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                       b_thread_vec.template AsType<mfma_input_type>(),
+                                       c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

@@ -459,32 +467,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
+                                  b_scale_thread_buf[n0],
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
                });
            });

-           static_for<0, NRepeat, 1>{}([&](auto n0) {
-               b_scale_thread_copy.Run(b_scale_grid_desc,
-                                       b_scale_grid_buf,
-                                       b_scale_thread_desc,
-                                       make_tuple(n0, I0),
-                                       b_scale_thread_buf);
-               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                      b_scale_thread_copy_step.At(Number<0>{}));
-           });
-
-           if((i + 2) % num_loop_per_scale == 0)
-           {
-               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                      b_scale_thread_copy_step.At(Number<2>{}));
-           }
-           else
-           {
-               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                      b_scale_thread_copy_step.At(Number<1>{}));
-           }
-
            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);

@@ -495,10 +483,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
-           static_for<0, MRepeat, 1>{}([&](auto m0) {
-               static_for<0, NRepeat, 1>{}([&](auto n0) {
-                   c_thread_buf_per_scale.Clear();
-                   static_for<0, KRepeat, 1>{}([&](auto k0) {
+           static_for<0, KRepeat, 1>{}([&](auto k0) {
+               static_for<0, MRepeat, 1>{}([&](auto m0) {
+                   static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

@@ -514,17 +501,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

-                       xdlops_gemm.template Run(
-                           a_thread_vec.template AsType<mfma_input_type>(),
-                           b_thread_vec.template AsType<mfma_input_type>(),
-                           c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                   });
-                   static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                       constexpr index_t c_offset =
-                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                       c_thread_buf(Number<c_offset>{}) +=
-                           c_thread_buf_per_scale[Number<t>{}] *
-                           // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-                           type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                       constexpr index_t c_offset =
+                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                       xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                       b_thread_vec.template AsType<mfma_input_type>(),
+                                       c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
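For context on the restructured hot loop (not part of the diff): applying the per-n0 B scale to the B fragment as it is copied, rather than accumulating unscaled partial products into c_thread_buf_per_scale and scaling afterwards, relies on the scale distributing over the dot product. A minimal scalar sketch with illustrative names:

// sum_k a[k] * (s * b[k]) == s * sum_k a[k] * b[k], so scaling B before the
// MFMA and scaling the accumulated partials afterwards agree up to
// floating-point rounding.
float dot_scale_before(const float* a, const float* b, int K, float s)
{
    float acc = 0.f;
    for(int k = 0; k < K; ++k)
        acc += a[k] * (s * b[k]); // scale fused into the B operand
    return acc;
}

float dot_scale_after(const float* a, const float* b, int K, float s)
{
    float acc = 0.f;
    for(int k = 0; k < K; ++k)
        acc += a[k] * b[k]; // unscaled partial products
    return s * acc;         // scale applied once at the end
}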
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp

(diff collapsed)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -220,9 +220,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
                    ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
-                      MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
-                          ? 1
-                          : 2
+                      MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
+                          ? 2
+                          : 1
                    : 2;

            if(has_main_k_block_loop)
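A small worked instance of the occupancy selection above, as a sketch: the tile shape is illustrative, and the v3/Intrawave condition is assumed to hold.

// With MPerBlock = NPerBlock = 128, KPerBlock = 64 and a 2-byte ADataType, the
// per-block footprint used by the heuristic sits exactly at the 128*128*64*2
// threshold, so the rewritten `<=` comparison keeps minimum_occupancy at 2.
constexpr long long MPerBlock = 128, NPerBlock = 128, KPerBlock = 64, ABytes = 2;
constexpr long long footprint = MPerBlock * NPerBlock * KPerBlock * ABytes;
constexpr int minimum_occupancy = (footprint <= 128LL * 128 * 64 * 2) ? 2 : 1;
static_assert(minimum_occupancy == 2, "a tile at the threshold still requests occupancy 2");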
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -11,6 +11,98 @@
namespace ck {

__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
{
    constexpr int LO = 0x000f000f;
    constexpr int HI = 0x00f000f0;
    constexpr int EX = 0x64006400;

    // Guarantee that the `(a & b) | c` operations are LOP3s.
    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
    int lo = amd_assembly_and_or_b32(q, LO, EX);
    int hi = amd_assembly_and_or_b32(q, HI, EX);

    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
    // directly into `SUB` and `ADD`.
    constexpr int SUB = 0xE408E408; //-8
    constexpr int MUL = 0x2c002c00; // 1/16
    constexpr int ADD = 0xd480d480; //-79

    vector_type<half_t, 4> res;

    res.template AsType<half2_t>()(Number<0>{}) =
        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));

    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));

    asm volatile("v_pk_mul_f16 %0, %1, %2"
                 : "=v"(res.template AsType<half2_t>()(Number<0>{}))
                 : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));

    asm volatile("v_pk_mul_f16 %0, %1, %2"
                 : "=v"(res.template AsType<half2_t>()(Number<1>{}))
                 : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));

    return res.template AsType<half4_t>()[Number<0>{}];
}

// Further fuse the scale into inline assembly, sanity failed
#if 0
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half_t& scale)
{
    constexpr int LO = 0x000f000f;
    constexpr int HI = 0x00f000f0;
    constexpr int EX = 0x64006400;
    // Guarantee that the `(a & b) | c` operations are LOP3s.
    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
    int lo = amd_assembly_and_or_b32(q, LO, EX);
    int hi = amd_assembly_and_or_b32(q, HI, EX);
    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
    // directly into `SUB` and `ADD`.
    // constexpr int SUB = 0xE408E408; //-8
    // constexpr int MUL = 0x2c002c00; // 1/16
    // constexpr int ADD = 0xd480d480; //-79
    constexpr half_t SUB = bit_cast<half_t>(static_cast<uint16_t>(0xE408));
    constexpr half_t MUL = bit_cast<half_t>(static_cast<uint16_t>(0x2c00));
    constexpr half_t ADD = bit_cast<half_t>(static_cast<uint16_t>(0xd480));
    vector_type<half_t, 2> scale_2;
    scale_2.template AsType<half_t>()(Number<0>{}) = scale;
    scale_2.template AsType<half_t>()(Number<1>{}) = scale;
    vector_type<half_t, 2> sub_2;
    sub_2.template AsType<half_t>()(Number<0>{}) = SUB * scale;
    sub_2.template AsType<half_t>()(Number<1>{}) = SUB * scale;
    vector_type<half_t, 2> mul_2;
    mul_2.template AsType<half_t>()(Number<0>{}) = MUL * scale;
    mul_2.template AsType<half_t>()(Number<1>{}) = MUL * scale;
    vector_type<half_t, 2> add_2;
    add_2.template AsType<half_t>()(Number<0>{}) = ADD * scale;
    add_2.template AsType<half_t>()(Number<1>{}) = ADD * scale;
    vector_type<half_t, 4> res;
    res.template AsType<half2_t>()(Number<0>{}) =
        amd_assembly_pk_fma_f16(bit_cast<half2_t>(lo),
                                scale_2.template AsType<half2_t>()(Number<0>{}),
                                sub_2.template AsType<half2_t>()(Number<0>{}));
    res.template AsType<half2_t>()(Number<1>{}) =
        amd_assembly_pk_fma_f16(bit_cast<half2_t>(hi),
                                mul_2.template AsType<half2_t>()(Number<0>{}),
                                add_2.template AsType<half2_t>()(Number<0>{}));
    // asm volatile("v_pk_mul_f16 %0, %1, %2"
    //              : "=v"(res.template AsType<half2_t>()(Number<0>{}))
    //              : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
    // asm volatile("v_pk_mul_f16 %0, %1, %2"
    //              : "=v"(res.template AsType<half2_t>()(Number<1>{}))
    //              : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
    return res.template AsType<half4_t>()[Number<0>{}];
}
#endif

__host__ __device__ inline half4_t pki4_to_half4(int q)
{
    constexpr int LO = 0x000f000f;

@@ -119,6 +211,69 @@ struct PassThroughPack8
        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);

        y = result.template AsType<half8_t>()[Number<0>{}];
#else
        vector_type<half_t, 8> dst;
        vector_type<pk_i4_t, 4> src{x};

        dst.template AsType<half2_t>()(Number<0>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
        dst.template AsType<half2_t>()(Number<1>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
        dst.template AsType<half2_t>()(Number<2>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
        dst.template AsType<half2_t>()(Number<3>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);

        y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
    }

    constexpr const static bool is_pack8_invocable = true;
};

struct DequantPack8
{
    template <typename Y, typename X, typename Z>
    __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;

    __host__ __device__ constexpr void
    operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
    {
#if 0
        int x_permute = 0;

        int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
        int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
        int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
        int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
        int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
        int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
        int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
        int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;

        x_permute |= (bits4_1 << 0);
        x_permute |= (bits4_3 << 4);
        x_permute |= (bits4_5 << 8);
        x_permute |= (bits4_7 << 12);
        x_permute |= (bits4_0 << 16);
        x_permute |= (bits4_2 << 20);
        x_permute |= (bits4_4 << 24);
        x_permute |= (bits4_6 << 28);

        vector_type<half_t, 8> result;

        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(x_permute, z);
        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(x_permute >> 8, z);

        y = result.template AsType<half8_t>()[Number<0>{}];
#elif 1
        vector_type<half_t, 8> result;

        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
        result.template AsType<half4_t>()(Number<1>{}) =
            pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);

        y = result.template AsType<half8_t>()[Number<0>{}];
#else
        vector_type<half_t, 8> dst;
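The packed conversion above decodes int4 nibbles by embedding them in the fp16 mantissa (the 0x6400 exponent trick), subtracting the fused -8 zero point, and multiplying by the per-group scale. A scalar reference of the intended per-element result, offered as a sketch: the nibble ordering of the real packed path is not modeled, and the offset-binary reading of the nibble is an assumption taken from the "-8 symmetric zero point" comment.

#include <cstdint>

// Per-nibble semantics only: value = (nibble - 8) * scale.
inline float dequant_i4_scalar(uint32_t packed, int i, float scale)
{
    const int nibble = (packed >> (4 * i)) & 0xF;  // raw 4-bit field, 0..15
    return static_cast<float>(nibble - 8) * scale; // fused symmetric zero point
}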
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -1914,7 +1914,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
            make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

        constexpr auto b_scale_thread_slice_copy_step =
-           make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
+           make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                      make_multi_index(-NPerBlock, 0),
+                      make_multi_index(-NPerBlock, 1));

        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1252,6 +1252,237 @@ struct ThreadwiseTensorSliceTransfer_v4
        });
    }

    // Fuse scale
    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstData& scale,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(
            is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
                is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
            "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        // scalar per access of each dim
        constexpr auto src_scalar_per_access = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<SrcScalarPerVector>{};
                }
                else
                {
                    return Number<1>{};
                }
            },
            Number<nDim>{});

        // scalar step (if steping on SrcVectorDim) of each dim
        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
            // TODO: unable to compile
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                src_scalar_per_access;
#else
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
            // src coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_step =
                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

            vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            if constexpr(SrcBuffer::IsDynamicBuffer())
            {
                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
                                                       is_src_valid);
            }
            else if constexpr(SrcBuffer::IsStaticBuffer())
            {
                static_assert(false, "");
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t src_offset = src_desc.CalculateOffset(
                        src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                        i * src_scalar_step_in_vector);

                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
                });
            }

            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
                         is_same<remove_cvref_t<DstData>, half_t>::value)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                vector_type<DstData, 2> scale_vector;
                scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
                scale_vector.template AsType<DstData>()(Number<1>{}) = scale;

                constexpr index_t pack_size = 8;

                static_assert(SrcScalarPerVector % pack_size == 0, "");

                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;

                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::DequantPack8{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i],
                        scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
                              is_same<remove_cvref_t<DstData>, f8_t>::value)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 8;

                static_assert(SrcScalarPerVector % pack_size == 0, "");

                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;

                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack8{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                              is_same<remove_cvref_t<DstData>, half_t>::value &&
                              SrcScalarPerVector % 2 == 0)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                constexpr index_t pack_size = 2;

                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
                using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;

                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
                    ck::tensor_operation::element_wise::PassThroughPack2{}(
                        dst_tmp_vector.template AsType<dst_v_t>()(i),
                        src_tmp_vector.template AsType<src_v_t>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
            else
            {
                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                // DstData)
                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    dst_tmp_vector.template AsType<DstData>()(i) =
                        type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
                });

                // copy data from dst_tmp_vector into dst_buf
                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                });
            }
        });
    }

    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
include/ck/utility/amd_inline_asm.hpp

@@ -21,14 +21,14 @@ inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t d;
-   asm volatile("v_pk_fma_f16 %0, %1, %2, %3 ;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+   asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}

inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t c;
-   asm volatile("v_pk_add_f16 %0, %1, %2 ;\n" : "=v"(c) : "v"(a), "v"(b));
+   asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
    return c;
}
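For reference (not part of the diff), the two wrappers above emit packed fp16 instructions that operate lane-wise on a half2_t. A float-based model of that behavior, as a sketch; the struct and function names below are illustrative:

// Lane-wise model of v_pk_fma_f16 / v_pk_add_f16 on a two-lane fp16 vector.
struct half2_model
{
    float lo, hi;
};

inline half2_model pk_fma_f16_model(half2_model a, half2_model b, half2_model c)
{
    return {a.lo * b.lo + c.lo, a.hi * b.hi + c.hi}; // d = a * b + c per lane
}

inline half2_model pk_add_f16_model(half2_model a, half2_model b)
{
    return {a.lo + b.lo, a.hi + b.hi}; // c = a + b per lane
}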
script/cmake-ck-dev.sh

@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
- -D CMAKE_HIP_FLAGS=" -save-temps -gline-tables-only -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker " \
+ -D CMAKE_HIP_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \