Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dda18da0
Unverified
Commit
dda18da0
authored
Oct 14, 2024
by
Illia Silin
Committed by
GitHub
Oct 14, 2024
Browse files
Merge branch 'develop' into ck_migraphx_integration
parents
3b2a7aee
4cf70b36
Changes
86
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
780 additions
and
99 deletions
+780
-99
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+40
-15
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+2
-1
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+6
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+625
-30
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+4
-0
include/ck_tile/core/container/thread_buffer.hpp
include/ck_tile/core/container/thread_buffer.hpp
+1
-1
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+15
-5
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+0
-6
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+37
-10
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
dda18da0
...
@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -43,16 +43,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
LayoutA
,
LayoutB
,
LayoutC
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
...
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
using
CodegenGemmTraits
=
ck_tile
::
BDataType
,
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
AccDataType
,
CodegenGemmShape
,
kPadA
,
kPadB
,
kPadC
>
;
using
CodegenGemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
invoke_gemm
<
ck_tile
::
half_t
,
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_a_layout
,
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
...
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
...
...
example/ck_tile/04_img2col/README.md
View file @
dda18da0
...
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
...
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
make tile_example_img2col -j
```
```
This will result in an executable
`build/bin/tile_example_img2col`
This will result in an executable
`build/bin/tile_example_img2col`
include/ck/config.h.in
View file @
dda18da0
...
@@ -97,13 +97,6 @@
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
//
// CK kernels which support XDL (MI series)
// CK kernels which support XDL (MI series)
//
//
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
dda18da0
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
dda18da0
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
dda18da0
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
dda18da0
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
dda18da0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
index_t
KRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
index_t
StrideC
)
const
=
0
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
dda18da0
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
index_t
StrideC
)
const
override
{
{
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
}
std
::
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
base_arg
)
const
override
{
const
auto
*
parg
=
dynamic_cast
<
const
Argument
*>
(
base_arg
);
if
(
!
parg
)
{
std
::
ostringstream
err
;
err
<<
"Provided argument pointer is not of an Argument class!"
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
return
GetWorkspaceSize
(
parg
->
M
,
parg
->
N
,
parg
->
K
,
parg
->
StrideA
,
parg
->
StrideB
,
parg
->
StrideC
);
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
dda18da0
...
@@ -64,7 +64,7 @@ __global__ void
...
@@ -64,7 +64,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
dda18da0
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
skipped_group_count_
++
;
skipped_group_count_
++
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
dda18da0
...
@@ -109,7 +109,7 @@ __global__ void
...
@@ -109,7 +109,7 @@ __global__ void
N
=
gemm_desc_ptr
[
group_id
].
N
;
N
=
gemm_desc_ptr
[
group_id
].
N
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
grid_size_grp
=
0
;
grid_size_grp
=
0
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
dda18da0
...
@@ -68,7 +68,7 @@ __global__ void
...
@@ -68,7 +68,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
dda18da0
...
@@ -419,6 +419,12 @@ struct UnaryAbs
...
@@ -419,6 +419,12 @@ struct UnaryAbs
y
=
ck
::
math
::
abs
(
x
);
y
=
ck
::
math
::
abs
(
x
);
};
};
template
<
>
__host__
__device__
void
operator
()(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
ck
::
type_convert
<
f8_t
>
(
ck
::
math
::
abs
(
ck
::
type_convert
<
float
>
(
x
)));
};
};
};
struct
UnarySqrt
struct
UnarySqrt
...
...
include/ck/utility/data_type.hpp
View file @
dda18da0
This diff is collapsed.
Click to expand it.
include/ck/utility/math_v2.hpp
View file @
dda18da0
...
@@ -81,6 +81,8 @@ static inline __host__ bool isnan(half_t x)
...
@@ -81,6 +81,8 @@ static inline __host__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
static
inline
__host__
bool
isnan
(
int4_t
x
)
{
{
...
@@ -531,6 +533,8 @@ static inline __device__ bool isnan(half_t x)
...
@@ -531,6 +533,8 @@ static inline __device__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__device__
half_t
sqrt
(
half_t
x
)
static
inline
__device__
half_t
sqrt
(
half_t
x
)
{
{
return
static_cast
<
half_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
return
static_cast
<
half_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
...
...
include/ck_tile/core/container/thread_buffer.hpp
View file @
dda18da0
include/ck_tile/host/arg_parser.hpp
View file @
dda18da0
...
@@ -50,12 +50,22 @@ class ArgParser
...
@@ -50,12 +50,22 @@ class ArgParser
}
}
return
*
this
;
return
*
this
;
}
}
void
print
()
void
print
()
const
{
{
// find max key length
std
::
string
::
size_type
max_key_length
=
11
;
for
(
auto
&
key
:
keys
)
{
if
(
max_key_length
<
key
.
length
())
{
max_key_length
=
key
.
length
();
}
}
printf
(
"args:
\n
"
);
printf
(
"args:
\n
"
);
for
(
auto
&
key
:
keys
)
for
(
auto
&
key
:
keys
)
{
{
auto
value
=
input_map
[
key
]
;
auto
value
=
input_map
.
at
(
key
)
;
std
::
vector
<
std
::
string
>
help_text_lines
;
std
::
vector
<
std
::
string
>
help_text_lines
;
size_t
pos
=
0
;
size_t
pos
=
0
;
for
(
size_t
next_pos
=
value
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
for
(
size_t
next_pos
=
value
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
...
@@ -69,8 +79,7 @@ class ArgParser
...
@@ -69,8 +79,7 @@ class ArgParser
std
::
string
(
value
.
help_text
.
begin
()
+
pos
,
value
.
help_text
.
end
()));
std
::
string
(
value
.
help_text
.
begin
()
+
pos
,
value
.
help_text
.
end
()));
std
::
string
default_value
=
std
::
string
(
"(default:"
)
+
value
.
value
+
std
::
string
(
")"
);
std
::
string
default_value
=
std
::
string
(
"(default:"
)
+
value
.
value
+
std
::
string
(
")"
);
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
-
value
.
name
.
length
())
<<
"-"
<<
key
std
::
cout
<<
std
::
setw
(
2
)
<<
std
::
setw
(
12
-
value
.
name
.
length
())
<<
"-"
<<
key
<<
std
::
setw
(
4
)
<<
" "
<<
help_text_lines
[
0
]
<<
" "
<<
default_value
<<
std
::
setw
(
4
)
<<
" "
<<
help_text_lines
[
0
]
<<
" "
<<
default_value
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -78,7 +87,8 @@ class ArgParser
...
@@ -78,7 +87,8 @@ class ArgParser
help_next_line
!=
help_text_lines
.
end
();
help_next_line
!=
help_text_lines
.
end
();
++
help_next_line
)
++
help_next_line
)
{
{
std
::
cout
<<
std
::
setw
(
17
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
+
4
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
}
}
}
}
}
}
...
...
include/ck_tile/host/convolution_parameter.hpp
View file @
dda18da0
...
@@ -13,7 +13,6 @@ namespace conv {
...
@@ -13,7 +13,6 @@ namespace conv {
struct
ConvParam
struct
ConvParam
{
{
ConvParam
();
ConvParam
(
ck_tile
::
index_t
n_dim
,
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
ck_tile
::
index_t
n_batch
,
...
@@ -199,11 +198,6 @@ struct ConvParam
...
@@ -199,11 +198,6 @@ struct ConvParam
}
}
};
};
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
}
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
{
std
::
string
msg
;
std
::
string
msg
;
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
dda18da0
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const
BElementOp
&
b_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
b_n_k
.
mDesc
.
get_lengths
()[
0
];
const
int
N
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_n_k
.
mDesc
.
get_lengths
()[
0
]
:
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
...
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_element_op
(
a_m_k
(
m
,
k
))
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
b_element_op
(
b_n_k
(
n
,
k
));
BDataType
v_b
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_element_op
(
b_n_k
(
n
,
k
))
:
b_element_op
(
b_n_k
(
k
,
n
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
c_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
CDataType
&
c_ref
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
c_m_n
(
m
,
n
)
:
c_m_n
(
n
,
m
);
c_ref
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
}
}
};
};
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
__global__
void
naive_gemm_kernel
(
ADataType
*
A
,
BDataType
*
B
,
BDataType
*
B
,
CDataType
*
C
,
CDataType
*
C
,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
...
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if
(
row
<
M
&&
col
<
N
)
if
(
row
<
M
&&
col
<
N
)
{
{
AccDataType
acc
=
0.0
;
AccDataType
acc
=
0.0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
acc
+=
static_cast
<
AccDataType
>
(
A
[
row
*
strideA
+
k
])
*
// Adjust indexing based on matrix layout
static_cast
<
AccDataType
>
(
B
[
col
*
strideB
+
k
]);
int
a_index
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideA
+
k
:
k
*
strideA
+
row
;
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
}
}
C
[
row
*
strideC
+
col
]
=
acc
;
// Store as AccDataType
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
}
}
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
DeviceMem
&
c_device
,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
errC
=
hipMemcpy
(
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment