gaoqiong / composable_kernel_ROCM · Commit 2fd6c6d4

Merge branch 'develop' into amd-develop

Authored Jan 31, 2024 by Jun Liu
Parents: c32d3448, 6651a124

Changes: 78 files in total; this page shows 20 changed files with 3452 additions and 274 deletions (+3452 −274).
Changed files shown on this page:

  example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp        +2    −3
  example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp                     +4    −5
  example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp  +2    −3
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp                         +2    −3
  example/35_splitK_gemm/run_splitK_gemm_example.inc                                  +1    −1
  example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp                                     +1    −1
  example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp                 +2    −3
  example/48_pool3d_fwd/pool3d_fwd_common.hpp                                         +4    −0
  example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp                                   +4    −0
  include/ck/stream_config.hpp                                                        +2    −2
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp            +999  −0
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp         +306  −0
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                         +301  −0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp              +1153 −0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                +30   −0
  include/ck/utility/data_type.hpp                                                    +65   −0
  include/ck/utility/is_known_at_compile_time.hpp                                     +7    −1
  include/ck/wrapper/layout.hpp                                                       +141  −49
  include/ck/wrapper/operations/copy.hpp                                              +137  −3
  include/ck/wrapper/tensor.hpp                                                       +289  −200
example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp

  // SPDX-License-Identifier: MIT
  // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

- #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
- #error Should compile this file with ck::int4_t support
- #endif
+ #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
  #define BUILD_INT4_EXAMPLE
  ...

@@ -24,3 +22,4 @@ using RsDataType = ck::Tuple<R0DataType>;
  #include "run_convnd_fwd_max_example.inc"
  int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
+ #endif
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp

@@ -272,15 +272,14 @@ int main(int argc, char* argv[])
      {
          for(int m = 0; m < M; ++m)
          {
              auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
              auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
+             ReduceAccDataType d0_val = 0;
+             ReduceAccDataType d1_val = 0;

              for(int n = 0; n < N; ++n)
              {
                  auto c_val =
                      ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
-                 ReduceAccDataType d0_val;
-                 ReduceAccDataType d1_val;
                  UnaryIdenticElementOp{}(d0_val, c_val);
                  UnarySquareElementOp{}(d1_val, c_val);
  ...
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp

  // SPDX-License-Identifier: MIT
  // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

- #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
- #error Should compile this file with ck::int4_t support
- #endif
+ #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
  #include "common.hpp"
  ...

@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
  #include "run_grouped_conv_fwd_bias_relu_add_example.inc"
  int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
+ #endif
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp

@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
     Gemm1
  */

- #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
- #error Should compile this file with ck::int4_t support
- #endif
+ #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
  #include <iostream>
  #include <numeric>
  ...

@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
  #endif
  int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
+ #endif
example/35_splitK_gemm/run_splitK_gemm_example.inc

@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
      if(config.time_kernel)
      {
-         float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+         float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1});

          std::size_t flop      = std::size_t(2) * M * N * K;
          std::size_t num_btype =
  ...
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp

@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
  using BElementOp = PassThrough;
  using CElementOp = PassThrough;

- static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding;

  using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
      // clang-format off
  ...
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp

  // SPDX-License-Identifier: MIT
  // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

- #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
- #error Should compile this file with ck::int4_t support
- #endif
+ #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
  #include <cstdlib>
  #include <iostream>
  ...

@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
  #endif
  int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
+ #endif
example/48_pool3d_fwd/pool3d_fwd_common.hpp

@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
          return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
      else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
          return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
+     throw std::runtime_error("Pool3d_fwd: problem with layout. ");
+     return {0, 0, 0, 0, 0};
  };

  template <typename TensorLayout>
  ...

@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
          return HostTensorDescriptor({N_, C_, D, H, W},
                                      {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
      }
+     throw std::runtime_error("Pool3d_fwd: problem with layout. ");
+     return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
  };

  template <typename DevicePoolFwdInstance,
  ...
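The added throw/return fallback only fires for unsupported layouts; the stride formulas themselves are unchanged. The following is a small standalone check, not part of the commit, that simply evaluates the two stride formulas above for a concrete C=16, D=8, H=32, W=32 tensor:

// Standalone sketch: prints the NCDHW and NDHWC strides computed by the formulas above.
#include <cstdio>

int main()
{
    const long C = 16, D = 8, H = 32, W = 32;
    // NCDHW: {C*D*H*W, D*H*W, H*W, W, 1}
    std::printf("NCDHW strides: %ld %ld %ld %ld %d\n", C * D * H * W, D * H * W, H * W, W, 1);
    // NDHWC (channels-last): {D*C*H*W, 1, C*H*W, W*C, C}
    std::printf("NDHWC strides: %ld %d %ld %ld %ld\n", D * C * H * W, 1, C * H * W, W * C, C);
    return 0;
}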
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp

@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
          return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
      else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
          return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
+     throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
+     return {0, 0, 0, 0, 0};
  };

  template <typename TensorLayout>
  ...

@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
          return HostTensorDescriptor({N_, C_, D, H, W},
                                      {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
      }
+     throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
+     return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
  };

  template <typename DevicePoolBwdInstance,
  ...
include/ck/stream_config.hpp

@@ -11,6 +11,6 @@ struct StreamConfig
      hipStream_t stream_id_ = nullptr;
      bool time_kernel_      = false;
      int log_level_         = 0;
-     int cold_niters_       = 1;
-     int nrepeat_           = 10;
+     int cold_niters_       = 5;
+     int nrepeat_           = 50;
  };
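The commit raises the default benchmarking parameters from 1 cold (warm-up) iteration and 10 repeats to 5 and 50. Below is a minimal sketch, not from the commit, that constructs a StreamConfig for timed runs with these fields spelled out; make_timed_stream_config is a hypothetical helper name.

// Minimal sketch: a StreamConfig for timed kernel runs with the new defaults written out.
#include <hip/hip_runtime.h>
#include "ck/stream_config.hpp"

inline StreamConfig make_timed_stream_config(hipStream_t stream)
{
    // Field order follows the struct definition above; 5 warm-up iterations and
    // 50 timed repeats are now the defaults, so the last two values are optional.
    return StreamConfig{stream, /*time_kernel_=*/true, /*log_level_=*/0,
                        /*cold_niters_=*/5, /*nrepeat_=*/50};
}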
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp  (new file, mode 100644)

  This diff is collapsed and not shown in this view (+999 lines).
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
// non-c-shuffle version currently has compiler issues with register spills, which further cause
// validation failures.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeTypeA       = CDataType,
          typename ComputeTypeB       = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
                                                     BLayout,
                                                     CLayout,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CElementwiseOperation>
{
    using DeviceOp = DeviceGemm_Xdl_CShuffleV2;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        GemmSpec,
        InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        PipelineVer,
        ComputeTypeA,
        ComputeTypeB>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

            float ave_time = 0;

            const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
            if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
                ave_time          = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            }
            else
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
                ave_time          = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemm_Xdl_CShuffleV2"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
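The new device op follows the usual CK host-side pattern: build an Argument, check support, then launch through an Invoker. The following is a hedged sketch, not part of the commit; run_gemm_once is a hypothetical helper, and it assumes the chosen DeviceGemm_Xdl_CShuffleV2 instantiation uses PassThrough for all three element-wise operations and that p_a/p_b/p_c are device buffers that were allocated and filled beforehand.

// Hedged sketch: driving a DeviceGemm_Xdl_CShuffleV2 instance from the host.
#include <stdexcept>

#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

template <typename DeviceGemmInstance, typename ADataType, typename BDataType, typename CDataType>
float run_gemm_once(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
                    ck::index_t M, ck::index_t N, ck::index_t K,
                    ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                      PassThrough{}, PassThrough{}, PassThrough{});

    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error("this DeviceGemm_Xdl_CShuffleV2 instance does not support the problem");
    }

    // With time_kernel_ = true, Invoker::Run returns the averaged kernel time in ms.
    return invoker.Run(argument, StreamConfig{nullptr, /*time_kernel_=*/true});
}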
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
      __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
          : M_(M), N_(N), M01_(M01)
      {
+ #if 0
+         if(get_thread_global_1d_id()==0){
+             printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
+         }
+ #endif
      }

      template <typename CGridDesc_M_N>
  ...

@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
          BlockToCTileMap_M00_N0_M01Adapt;
  };

Everything below, up to the trailing comment, is newly added:

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;

template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
        const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
        BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                                index_t M01 = 8)
        : BlockToCTileMap_Grouped_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        const auto group_size  = math::integer_divide_ceil(M0 * N0, GroupNum);
        auto group_id          = block_1d_id % GroupNum;
        auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;

        index_t idx_N0 = remap_block_1d_id % N0;
        index_t idx_M0 = remap_block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        /**
         * (The source carries an ASCII diagram of the M0 x N0 tile grid and the
         *  M01-wide column-strip traversal; its alignment is not recoverable here.)
         *
         * Example:
         * assume:
         *   M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2
         *
         *   idx_N0 = 1
         *   idx_M0 = 1
         *   M01_adapt = 2
         *   idx_M00 = 0
         *   idx_M01 = 1
         *   idx_N0_M01_local = 5
         *   output {1, 2}
         */
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};

// keep the redundant type argument for backward compatibility
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
    : BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
    using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
        BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};

// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_M0 = block_1d_id % M0;
        index_t idx_N0 = block_1d_id / M0;

        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;

        index_t idx_N00          = idx_N0 / N01_;
        index_t idx_N01          = idx_N0 % N01_;
        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;

        /**
         * (The source carries an ASCII diagram of the M0 x N0 tile grid and the
         *  N01-wide row-strip traversal; its alignment is not recoverable here.)
         *
         * Example:
         * assume:
         *   N0 = 5, M0 = 4, block_1d_id = 5, N01 = 2
         *
         *   idx_M0 = 1
         *   idx_N0 = 1
         *   N01_adapt = 2
         *   idx_N00 = 0
         *   idx_N01 = 1
         *   idx_M0_N01_local = 5
         *   output {2, 1}
         */
        return make_tuple(idx_M0_N01_local / N01_adapt,
                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};

// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
...
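The grouped map first folds the (possibly batched) 1-D block id into the M0 x N0 tile grid, distributes ids round-robin over GroupNum groups, and then walks each group's tiles in M01-wide column strips. The following is a standalone host-side sketch, not part of the commit, that reproduces just that index arithmetic:

// Standalone sketch of the grouped remapping in CalculateBottomIndex (host-only, plain C++).
#include <cstdio>

int main()
{
    const int GroupNum = 4, M0 = 8, N0 = 4, M01 = 2;
    for(int block_1d_id = 0; block_1d_id < 8; ++block_1d_id)
    {
        const int group_size        = (M0 * N0 + GroupNum - 1) / GroupNum; // ceil-div
        const int group_id          = block_1d_id % GroupNum;
        const int remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;

        const int idx_N0 = remap_block_1d_id % N0;
        const int idx_M0 = remap_block_1d_id / N0;

        const int M01_adapt        = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;
        const int idx_M00          = idx_M0 / M01;
        const int idx_M01          = idx_M0 % M01;
        const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        std::printf("block %d -> C tile (m=%d, n=%d)\n",
                    block_1d_id,
                    idx_N0_M01_local % M01_adapt + idx_M00 * M01,
                    idx_N0_M01_local / M01_adapt);
    }
    return 0;
}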
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp  (new file, mode 100644)

  This diff is collapsed and not shown in this view (+1153 lines).
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
              make_tuple(Sequence<1>{}, Sequence<0>{}),
              make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
      }
+     else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+     {
+         const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+             a_grid_desc_m_k,
+             make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+             make_tuple(Sequence<0>{}, Sequence<1>{}),
+             make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+         return transform_tensor_descriptor(
+             a_grid_desc_m_kpad,
+             make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                        make_pass_through_transform(M)),
+             make_tuple(Sequence<1>{}, Sequence<0>{}),
+             make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+     }
      else
      {
          return transform_tensor_descriptor(
  ...

@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
              make_tuple(Sequence<0>{}, Sequence<1>{}),
              make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
      }
+     else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+     {
+         const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+             b_grid_desc_k_n,
+             make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+             make_tuple(Sequence<0>{}, Sequence<1>{}),
+             make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+         return transform_tensor_descriptor(
+             b_grid_desc_kpad_n,
+             make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                        make_pass_through_transform(N)),
+             make_tuple(Sequence<0>{}, Sequence<1>{}),
+             make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+     }
      else
      {
          return transform_tensor_descriptor(
  ...
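Both new branches pad K on the right before unmerging it into (KBatch, K0Padded, K1). The following is a hedged arithmetic sketch, not from the commit: the exact rounding of K0Padded inside the library may differ (it also respects the per-block K tiling); the point is only that KPad = KBatch * K0Padded * K1 >= K and the right-pad amount is KPad - K.

// Hedged sketch of the K-padding relation used by the unmerge above.
#include <cstdio>

int main()
{
    const int K = 1000, K1 = 8, KBatch = 4;
    const int K0Padded = (K + K1 * KBatch - 1) / (K1 * KBatch); // simplest rounding; assumed
    const int KPad     = KBatch * K0Padded * K1;                // right-pad K by (KPad - K)
    std::printf("K=%d -> K0Padded=%d, KPad=%d, pad=%d\n", K, K0Padded, KPad, KPad - K);
    return 0;
}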
include/ck/utility/data_type.hpp

@@ -189,6 +189,7 @@ struct vector_type<T, 1>
      }
  };

+ int static err = 0;

  template <typename T>
  struct vector_type<T, 2>
  {
  ...

@@ -221,6 +222,10 @@ struct vector_type<T, 2>
          {
              return data_.d2x1_;
          }
+         else
+         {
+             return err;
+         }
      }

      template <typename X>
  ...

The same four-line "else { return err; }" fallback is added after the corresponding member access in the second AsType() accessor of vector_type<T, 2> (@@ -236,6 +241,10 @@) and in both AsType() accessors of every wider specialization: vector_type<T, 4> (data_.d4x1_), <T, 8> (data_.d8x1_), <T, 16> (data_.d16x1_), <T, 32> (data_.d32x1_), <T, 64> (data_.d64x1_), <T, 128> (data_.d128x1_) and <T, 256> (data_.d256x1_).
include/ck/utility/is_known_at_compile_time.hpp

  // SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

  #pragma once
  ...

@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
      static constexpr bool value = false;
  };

+ template <>
+ struct is_known_at_compile_time<unsigned int>
+ {
+     static constexpr bool value = false;
+ };
+
  template <>
  struct is_known_at_compile_time<long_index_t>
  {
  ...
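A quick compile-time illustration, not part of the commit, of what the new specialization reports; the second assertion assumes the pre-existing specialization for ck::integral_constant (which backs ck::Number<>) is unchanged.

// Sketch: both assertions are evaluated entirely at compile time.
#include "ck/utility/is_known_at_compile_time.hpp"
#include "ck/utility/number.hpp"

static_assert(!ck::is_known_at_compile_time<unsigned int>::value,
              "a plain unsigned int is only known at run time");
static_assert(ck::is_known_at_compile_time<ck::Number<4>>::value,
              "ck::Number<> wraps a compile-time constant");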
include/ck/wrapper/layout.hpp

  // SPDX-License-Identifier: MIT
- // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

  #pragma once
  ...

@@ -14,22 +14,28 @@ namespace wrapper {
   * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
   * (dynamic layout). It is possible to pass nested shapes
   * (e.g. ((4, 2), 2)), nested dimensions are merged.
-  * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
+  * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
   */
- template <typename Shape, typename UnnestedDescriptorType>
+ template <typename Shape, typename UnrolledDescriptorType>
  struct Layout
  {
      private:
      static constexpr auto I0 = Number<0>{};
      static constexpr auto I1 = Number<1>{};

-     // Generate default idxs tuple (idx with all merged nested shapes)
+     /**
+      * \brief Generate default indices tuple (idx with all merged nested shapes)
+      *
+      * \param shape Shape to align.
+      * \return Multi idx tuple with zeros.
+      */
      template <typename... Ts>
-     __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
+     __host__ __device__ constexpr static auto
+     GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
      {
          return generate_tuple(
              [&](auto) {
-                 if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
+                 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                  {
                      // runtime layout
                      return index_t(0);
  ...

@@ -43,11 +49,18 @@ struct Layout
              Number<Tuple<Ts...>::Size()>{});
      }

-     // Generate LowerDims in Compile-time for MergeTransform using passed Type
-     // If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
-     // If tuple is element, then pass through (sequence with one element)
+     /**
+      * \brief Generate lower dims in compile-time for the Merge transform using
+      * provided type. If element of nested Tuple<Ts...> is also a tuple, then
+      * merge (generate sequence for merge). If tuple is element, then pass
+      * through (sequence with one element).
+      *
+      * \param shape Shape to align.
+      * \return LowerDims for MergeTransform.
+      */
      template <typename Idx, typename... Ts>
-     __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
+     __host__ __device__ constexpr static auto
+     GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
      {
          if constexpr(Idx::value == 0)
          {
  ...

@@ -87,11 +100,17 @@ struct Layout
          }
      }

-     // Iterate over nested tuples in shape
-     // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
-     // Example idx: (1, 1), 1, 1
-     // Example shape: (2, (2, 2)), 2, (2, 2)
-     // Unrolled shape: 2, (2, 2), 2, (2, 2)
+     /**
+      * \brief Iterate over the nested tuples in the shape.
+      * Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
+      * Example idx: (1, 1), 1, 1
+      * Example shape: (2, (2, 2)), 2, (2, 2)
+      * Unrolled shape: 2, (2, 2), 2, (2, 2)
+      *
+      * \param shape Layout shape.
+      * \param idx Idx to align.
+      * \return Aligned shape.
+      */
      template <typename... ShapeDims, typename... IdxDims>
      __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                                const Tuple<IdxDims...>& idx)
  ...

@@ -126,6 +145,13 @@ struct Layout
          }
      }

+     /**
+      * \brief Merge descriptor to 1D.
+      *
+      * \param shape Layout shape.
+      * \param desc Descriptor to merge.
+      * \return 1D descriptor.
+      */
      template <typename... ShapeDims, typename DescriptorToMerge>
      __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                            const DescriptorToMerge& desc)
  ...

@@ -137,18 +163,41 @@ struct Layout
          const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
          const auto upper_dims = make_tuple(Sequence<0>{});
          // Merge to 1d
-         return transform_tensor_descriptor(
-             desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+         if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+         {
+             return transform_tensor_descriptor(
+                 desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+         }
+         else
+         {
+             // If the descriptor is known at the compilation time,
+             // use `make_merge_transform_v1_carry_check` because it doesn't use
+             // memcpy.
+             return transform_tensor_descriptor(
+                 desc,
+                 make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
+                 lower_dims,
+                 upper_dims);
+         }
      }

-     // Merge nested shape dims when corresponding index is also nested.
-     // Input desc shape: 2, 2, 2, 2, 2, 2
-     // Example idx: 1, 1, 1, 1
-     // Example shape: 2, (2, 2), 2, (2, 2)
-     // Merged shape: 2, 4, 2, 4
+     /**
+      * \brief Merge nested shape dims when corresponding index is also merged.
+      * Input desc shape: 2, 2, 2, 2, 2, 2
+      * Example idx: 1, 1, 1, (1, 1)
+      * Example shape: 2, (2, 2), 2, (2, 2)
+      * Merged shape: 2, 4, 2, 2, 2
+      *
+      * \param shape Layout shape.
+      * \param idxs Indexes to align descriptor.
+      * \param desc Descriptor to merge.
+      * \return Aligned descriptor to idx.
+      */
      template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
-     __host__ __device__ constexpr static auto CreateMergedDescriptor(
-         const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
+     __host__ __device__ constexpr static auto
+     CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
+                            [[maybe_unused]] const Tuple<IdxDims...>& idxs,
+                            DescriptorToMerge& desc)
      {
          const auto transforms = generate_tuple(
              [&](auto i) {
  ...

@@ -160,7 +209,17 @@ struct Layout
                  // If shape element is tuple and idx element is Number, then merge
                  // Unroll and reverse tuple to traverse column-major
                  const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
-                 return make_merge_transform(merge_elems);
+                 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+                 {
+                     return make_merge_transform(merge_elems);
+                 }
+                 else
+                 {
+                     // If the descriptor is known at the compilation time,
+                     // use `make_merge_transform_v1_carry_check` because
+                     // it doesn't use memcpy.
+                     return make_merge_transform_v1_carry_check(merge_elems);
+                 }
              }
              else
              {
  ...

@@ -185,14 +244,23 @@ struct Layout
      }

      using Descriptor1dType =
-         remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
+         remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;

      using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

      public:
+     /**
+      * \brief Transform descriptor to align to passed indexes.
+      *
+      * \param shape Layout shape.
+      * \param idxs Indexes to align descriptor.
+      * \param naive_descriptor Descriptor to merge.
+      * \return Aligned descriptor to idx.
+      */
      template <typename... ShapeDims, typename... IdxDims>
      __host__ __device__ constexpr static auto TransformDesc(const Tuple<ShapeDims...>& shape,
-                                                             const Tuple<IdxDims...>& idx,
-                                                             const UnnestedDescriptorType& naive_descriptor)
+                                                             const Tuple<IdxDims...>& idxs,
+                                                             const UnrolledDescriptorType& naive_descriptor)
      {
          if constexpr(Tuple<IdxDims...>::Size() == I1)
          {
  ...

@@ -208,19 +276,18 @@ struct Layout
              static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                            "Idx rank and Shape rank must be the same (except 1d).");
              // Unroll while IdxDims is nested
-             const auto aligned_shape = AlignShapeToIdx(shape, idx);
+             const auto aligned_shape = AlignShapeToIdx(shape, idxs);
              // Transform correct form of shape
-             return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
+             return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
          }
      }

      using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
-         Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
+         Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;

      public:
      __host__ __device__ constexpr auto GetElementSpaceSize() const
      {
-         return unnested_descriptor_.GetElementSpaceSize();
+         return unrolled_descriptor_.GetElementSpaceSize();
      }

      __host__ __device__ Layout() = delete;
  ...

@@ -232,16 +299,15 @@ struct Layout
       * \param unnested_descriptor Descriptor
       */
      __host__ __device__ constexpr Layout(const Shape& shape,
-                                          const UnnestedDescriptorType& unnested_descriptor)
-         : shape_(shape)
+                                          const UnrolledDescriptorType& unnested_descriptor)
+         : unrolled_descriptor_(unnested_descriptor), shape_(shape)
      {
          // Construct if runtime mode
-         if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
+         if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
          {
-             unnested_descriptor_ = unnested_descriptor;
-             descriptor_1d_       = MakeMerge1d(shape_, unnested_descriptor_);
+             descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
              merged_nests_descriptor_ =
-                 TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
+                 TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
          }
      }
  ...

@@ -254,9 +320,9 @@ struct Layout
      template <typename Idxs>
      __host__ __device__ constexpr index_t operator()() const
      {
-         static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
+         static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
                        "Compiletime operator used on runtime layout.");
-         using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
+         using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
          using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
          return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
      }
  ...

@@ -283,7 +349,7 @@ struct Layout
          else
          {
              // Custom index, need to transform descriptor
-             const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
+             const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
              return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
          }
      }
  ...

@@ -350,29 +416,55 @@ struct Layout
      }

      /**
-      * \brief Get default descriptor (with the same size as Shape)
+      * \brief Get descriptor with all nested dimensions merged.
+      * Example, shape: ((2, 2), 2)
+      * Descriptor lengths: (4, 2)
       *
-      * \return Default descriptor.
+      * \note The size of merged descriptor is the same as Layout's shape.
+      *
+      * \return Merged nests descriptor.
       */
-     __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
+     __host__ __device__ constexpr const MergedNestsDescriptorType& GetMergedNestingDescriptor() const
      {
          return merged_nests_descriptor_;
      }

+     /**
+      * \brief Get descriptor with all dimensions merged (1D).
+      * Example, shape: ((2, 2), 2)
+      * Descriptor lengths: (8)
+      *
+      * \return 1D descriptor.
+      */
+     __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
+     {
+         return descriptor_1d_;
+     }
+
      /**
       * \brief Get unnested descriptor (with unrolled dims)
+      * Example, shape: ((2, 2), 2)
+      * Descriptor lengths: (2, 2, 2)
       *
-      * \return Flatten descriptor.
+      * \return Flattened descriptor.
       */
-     __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
+     __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
      {
-         return unnested_descriptor_;
+         return unrolled_descriptor_;
      }

      private:
-     UnnestedDescriptorType unnested_descriptor_;
+     // All dimensions are unrolled
+     UnrolledDescriptorType unrolled_descriptor_;
+     // 1D descriptor
      Descriptor1dType descriptor_1d_;
+     // All nesting are merged
      MergedNestsDescriptorType merged_nests_descriptor_;
+     // Example, shape: ((2, 2), 2)
+     // UnrolledDescriptorType lengths: (2, 2, 2)
+     // Descriptor1dType lengths: (8)
+     // MergedNestsDescriptorType lengths: (4, 2)
      const Shape shape_;
  };
  ...
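With the rename, the three accessors expose three views of the same layout. The following is a hedged sketch, not from the commit; it assumes that a packed descriptor built with ck::make_naive_tensor_descriptor_packed is an acceptable UnrolledDescriptorType and that the Layout constructor is usable directly as shown.

// Hedged sketch: one Layout, three descriptor views (helper choices are assumptions).
#include "ck/wrapper/layout.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"

inline auto make_example_layout()
{
    using ck::Number;
    // Nested shape ((2, 2), 2) over a packed compile-time descriptor with dims 2 x 2 x 2.
    auto shape = ck::make_tuple(ck::make_tuple(Number<2>{}, Number<2>{}), Number<2>{});
    auto desc  = ck::make_naive_tensor_descriptor_packed(
        ck::make_tuple(Number<2>{}, Number<2>{}, Number<2>{}));

    auto layout = ck::wrapper::Layout<decltype(shape), decltype(desc)>{shape, desc};
    // layout.GetUnrolledDescriptor()      -> lengths (2, 2, 2)  (flattened dims)
    // layout.GetMergedNestingDescriptor() -> lengths (4, 2)     (nests merged)
    // layout.Get1DDescriptor()            -> lengths (8)        (everything merged)
    return layout;
}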
include/ck/wrapper/operations/copy.hpp

  // SPDX-License-Identifier: MIT
- // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

  #pragma once

  #include "../utils/tensor_utils.hpp"

+ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
+ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
+ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

  namespace ck {
  namespace wrapper {

  /**
-  * \brief Perform generic copy between two tensors. Tensors must have the
-  * same size.
+  * \brief Perform generic copy between two tensor partitions (threadwise copy).
+  * Tensors must have the same size.
   *
   * \param src_tensor Source tensor.
   * \param dst_tensor Destination tensor.
  ...

@@ -37,5 +42,134 @@ __host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& ds
      }
  }

The remainder of the hunk is a newly added overload:

/**
 * \brief Perform optimized copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto thread_slice_lengths =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<false>,
            Sequence<false>>{in_grid_desc,
                             make_tuple(src_tensor.GetMultiIdxOffsets()),
                             out_grid_desc,
                             make_tuple(dst_tensor.GetMultiIdxOffsets()),
                             tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v1r3<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            tensor_operation::element_wise::PassThrough,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            InMemoryDataOperationEnum::Set,
            I1,
            true>{out_grid_desc,
                  dst_tensor.GetMultiIdxOffsets(),
                  tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto src_dst_slice_origin =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
            [&](auto I) {
                if constexpr(I == VectorDim)
                {
                    return Number<ScalarPerVector>{};
                }
                else
                {
                    return I1;
                }
            },
            Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v4r1<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            decltype(src_vector_tensor_lengths),
            decltype(dim_access_order)>{src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_dst_slice_origin,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     src_dst_slice_origin,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        copy(src_tensor, dst_tensor);
    }
}

} // namespace wrapper
} // namespace ck
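A hedged usage sketch, not from the commit, of the new vectorized overload; copy_tile is a hypothetical device-side helper and assumes src and dst are ck::wrapper tensors with identical 2-D per-thread shapes.

// Hedged sketch: per-thread tile copy with explicit access order and vectorization.
#include "ck/wrapper/operations/copy.hpp"

template <typename SrcTensorType, typename DstTensorType>
__device__ void copy_tile(const SrcTensorType& src, DstTensorType& dst)
{
    // Visit dim 0 then dim 1; vectorize along dim 1 with 8 scalars per access.
    using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    ck::wrapper::copy<DimAccessOrder, /*VectorDim=*/1, /*ScalarPerVector=*/8>(src, dst);
}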
include/ck/wrapper/tensor.hpp

  This diff is collapsed and not shown in this view (+289 −200).