Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
874a78f9
Commit
874a78f9
authored
Feb 09, 2024
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
6368be50
2fd6c6d4
Changes
89
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2817 additions
and
50 deletions
+2817
-50
example/01_gemm/gemm_xdl_int4.cpp
example/01_gemm/gemm_xdl_int4.cpp
+2
-3
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
+2
-3
example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp
...wd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp
+2
-3
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+4
-5
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
...wd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
+2
-3
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
+2
-3
example/35_splitK_gemm/run_splitK_gemm_example.inc
example/35_splitK_gemm/run_splitK_gemm_example.inc
+1
-1
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+1
-1
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
..._grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
+2
-3
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+4
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
+4
-0
include/ck/ck.hpp
include/ck/ck.hpp
+1
-1
include/ck/host_utility/hip_check_error.hpp
include/ck/host_utility/hip_check_error.hpp
+15
-13
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+8
-5
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+999
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
...or_operation/gpu/device/impl/device_contraction_utils.hpp
+6
-4
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
+306
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+301
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
+1153
-0
No files found.
example/01_gemm/gemm_xdl_int4.cpp
View file @
874a78f9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp"
#include "common.hpp"
...
@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
...
@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
#endif
\ No newline at end of file
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
View file @
874a78f9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp"
#include "common.hpp"
...
@@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
...
@@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
#include "run_gemm_add_add_fastgelu_example.inc"
#include "run_gemm_add_add_fastgelu_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_add_add_fastgelu_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_add_add_fastgelu_example
(
argc
,
argv
);
}
#endif
example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp
View file @
874a78f9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#define BUILD_INT4_EXAMPLE
#define BUILD_INT4_EXAMPLE
...
@@ -24,3 +22,4 @@ using RsDataType = ck::Tuple<R0DataType>;
...
@@ -24,3 +22,4 @@ using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
#include "run_convnd_fwd_max_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_fwd_max_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_convnd_fwd_max_example
(
argc
,
argv
);
}
#endif
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
874a78f9
...
@@ -272,15 +272,14 @@ int main(int argc, char* argv[])
...
@@ -272,15 +272,14 @@ int main(int argc, char* argv[])
{
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
auto
reduce0_acc
=
reduce0_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
auto
reduce0_acc
=
reduce0_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
auto
reduce1_acc
=
reduce1_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
auto
reduce1_acc
=
reduce1_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
ReduceAccDataType
d0_val
=
0
;
ReduceAccDataType
d1_val
=
0
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
auto
c_val
=
auto
c_val
=
ck
::
type_convert
<
ReduceAccDataType
>
(
c_g_m_n_host_result
(
batch
,
m
,
n
));
ck
::
type_convert
<
ReduceAccDataType
>
(
c_g_m_n_host_result
(
batch
,
m
,
n
));
ReduceAccDataType
d0_val
;
ReduceAccDataType
d1_val
;
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
...
...
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
View file @
874a78f9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp"
#include "common.hpp"
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
#endif
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
View file @
874a78f9
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
Gemm1
*/
*/
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/35_splitK_gemm/run_splitK_gemm_example.inc
View file @
874a78f9
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
1
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
874a78f9
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
// clang-format off
...
...
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
View file @
874a78f9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
874a78f9
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolFwdInstance
,
template
<
typename
DevicePoolFwdInstance
,
...
...
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
View file @
874a78f9
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolBwdInstance
,
template
<
typename
DevicePoolBwdInstance
,
...
...
include/ck/ck.hpp
View file @
874a78f9
...
@@ -218,7 +218,7 @@
...
@@ -218,7 +218,7 @@
// denorm test fix, required to work around dissue
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#define CK_WORKAROUND_DENORM_FIX 0
#el
if
#el
se
// enable only on MI200
// enable only on MI200
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
#endif // CK_WORKAROUND_DENORM_FIX
...
...
include/ck/host_utility/hip_check_error.hpp
View file @
874a78f9
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
if
(
x
!=
hipSuccess
)
if
(
x
!=
hipSuccess
)
{
{
std
::
ostringstream
ss
;
std
::
ostringstream
ss
;
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
__FILE__
<<
": "
<<
__LINE__
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
"in function: "
<<
__func__
;
<<
"hip_check_error.hpp"
<<
": "
<<
__LINE__
<<
"in function: "
<<
__func__
;
throw
std
::
runtime_error
(
ss
.
str
());
throw
std
::
runtime_error
(
ss
.
str
());
}
}
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
do \
{ \
{ \
hipError_t _tmpVal = retval_or_funcall; \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
if(_tmpVal != hipSuccess) \
{ \
{ \
std::ostringstream ostr; \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
ostr << "HIP Function Failed (" \
<< hipGetErrorString(_tmpVal); \
<< "hip_check_error.hpp" \
throw std::runtime_error(ostr.str()); \
<< "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \
} \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
} while(0)
include/ck/host_utility/kernel_launch.hpp
View file @
874a78f9
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim
.
y
,
block_dim
.
y
,
block_dim
.
z
);
block_dim
.
z
);
printf
(
"Warm up
1
time
\n
"
);
printf
(
"Warm up
%d
time
s
\n
"
,
stream_config
.
cold_niters_
);
#endif
#endif
// warm up
// warm up
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
...
@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
block_dim
.
y
,
block_dim
.
y
,
block_dim
.
z
);
block_dim
.
z
);
printf
(
"Warm up
1
time
\n
"
);
printf
(
"Warm up
%d
time
s
\n
"
,
stream_config
.
cold_niters_
);
#endif
#endif
// warm up
// warm up
preprocess
();
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
hip_check_error
(
hipGetLastError
());
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hip_check_error
(
hipGetLastError
());
}
const
int
nrepeat
=
10
;
const
int
nrepeat
=
stream_config
.
nrepeat_
;
#if DEBUG_LOG
#if DEBUG_LOG
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
#endif
#endif
...
...
include/ck/stream_config.hpp
View file @
874a78f9
...
@@ -11,6 +11,6 @@ struct StreamConfig
...
@@ -11,6 +11,6 @@ struct StreamConfig
hipStream_t
stream_id_
=
nullptr
;
hipStream_t
stream_id_
=
nullptr
;
bool
time_kernel_
=
false
;
bool
time_kernel_
=
false
;
int
log_level_
=
0
;
int
log_level_
=
0
;
int
cold_niters_
=
1
;
int
cold_niters_
=
5
;
int
nrepeat_
=
1
0
;
int
nrepeat_
=
5
0
;
};
};
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
0 → 100644
View file @
874a78f9
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
View file @
874a78f9
...
@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
...
@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
if
(
lengths
.
size
()
!=
NumDim1
+
NumDim2
)
if
(
lengths
.
size
()
!=
NumDim1
+
NumDim2
)
{
{
std
::
ostringstream
err
;
std
::
ostringstream
err
;
err
<<
"Incorrect number of lengths in "
<<
__FILE__
<<
":"
<<
__LINE__
err
<<
"Incorrect number of lengths in "
<<
", in function: "
<<
__func__
;
<<
"device_contraction_utils.hpp"
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
throw
std
::
runtime_error
(
err
.
str
());
}
}
if
(
strides
.
size
()
!=
NumDim1
+
NumDim2
)
if
(
strides
.
size
()
!=
NumDim1
+
NumDim2
)
{
{
std
::
ostringstream
err
;
std
::
ostringstream
err
;
err
<<
"Incorrect number of strides in "
<<
__FILE__
<<
":"
<<
__LINE__
err
<<
"Incorrect number of strides in "
<<
", in function: "
<<
__func__
;
<<
"device_contraction_utils.hpp"
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
throw
std
::
runtime_error
(
err
.
str
());
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp
0 → 100644
View file @
874a78f9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGemm_Xdl_CShuffleV2
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemm_Xdl_CShuffleV2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v2
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
,
ComputeTypeA
,
ComputeTypeB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
);
float
ave_time
=
0
;
const
auto
K
=
GridwiseGemm
::
CalculateAK0
(
arg
.
K
)
*
AK1
;
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K
)
==
3
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
true
,
2
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffleV2"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
874a78f9
...
@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
{
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
...
@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
...
@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
BlockToCTileMap_M00_N0_M01Adapt
;
BlockToCTileMap_M00_N0_M01Adapt
;
};
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template
<
index_t
GroupNum
,
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
=
void
>
struct
BlockToCTileMap_Grouped_M00_N0_M01Adapt
;
template
<
index_t
GroupNum
,
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
GroupNum
,
MPerBlock
,
NPerBlock
,
void
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
(
const
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
(
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&
operator
=
(
const
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&
operator
=
(
BlockToCTileMap_Grouped_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
=
8
)
:
BlockToCTileMap_Grouped_M00_N0_M01Adapt
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
{
}
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
M0
*
N0
;
}
template
<
typename
CGridDesc_M_N
>
__host__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
template
<
typename
TopIdx
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
{
auto
block_1d_id
=
idx_top
[
I0
];
const
auto
M0
=
math
::
integer_divide_ceil
(
M_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N_
,
NPerBlock
);
block_1d_id
=
block_1d_id
%
(
M0
*
N0
);
// swallow batch index
const
auto
group_size
=
math
::
integer_divide_ceil
(
M0
*
N0
,
GroupNum
);
auto
group_id
=
block_1d_id
%
GroupNum
;
auto
remap_block_1d_id
=
group_id
*
group_size
+
block_1d_id
/
GroupNum
;
index_t
idx_N0
=
remap_block_1d_id
%
N0
;
index_t
idx_M0
=
remap_block_1d_id
/
N0
;
const
auto
M01_adapt
=
(
idx_M0
<
M0
-
M0
%
M01_
)
?
M01_
:
M0
%
M01_
;
index_t
idx_M00
=
idx_M0
/
M01_
;
index_t
idx_M01
=
idx_M0
%
M01_
;
index_t
idx_N0_M01_local
=
idx_N0
+
idx_M01
*
N0
;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return
make_tuple
(
idx_N0_M01_local
%
M01_adapt
+
idx_M00
*
M01_
,
idx_N0_M01_local
/
M01_adapt
);
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t
M_
;
index_t
N_
;
index_t
M01_
;
};
// keep the redundant type argument for backward compatibility
template
<
index_t
GroupNum
,
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
struct
BlockToCTileMap_Grouped_M00_N0_M01Adapt
:
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
GroupNum
,
MPerBlock
,
NPerBlock
,
void
>
{
using
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
GroupNum
,
MPerBlock
,
NPerBlock
,
void
>::
BlockToCTileMap_Grouped_M00_N0_M01Adapt
;
};
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
=
void
>
struct
BlockToCTileMap_N00_M0_N01Adapt
;
template
<
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_N00_M0_N01Adapt
<
MPerBlock
,
NPerBlock
,
void
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
()
=
default
;
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
(
const
BlockToCTileMap_N00_M0_N01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
(
BlockToCTileMap_N00_M0_N01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
&
operator
=
(
const
BlockToCTileMap_N00_M0_N01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
&
operator
=
(
BlockToCTileMap_N00_M0_N01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
(
index_t
M
,
index_t
N
,
index_t
N01
=
8
)
:
M_
(
M
),
N_
(
N
),
N01_
(
N01
)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
}
#endif
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
BlockToCTileMap_N00_M0_N01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
N01
=
8
)
:
BlockToCTileMap_N00_M0_N01Adapt
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
N01
)
{
}
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
M0
*
N0
;
}
template
<
typename
CGridDesc_M_N
>
__host__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
template
<
typename
TopIdx
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
{
auto
block_1d_id
=
idx_top
[
I0
];
const
auto
M0
=
math
::
integer_divide_ceil
(
M_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N_
,
NPerBlock
);
block_1d_id
=
block_1d_id
%
(
M0
*
N0
);
// swallow batch index
index_t
idx_M0
=
block_1d_id
%
M0
;
index_t
idx_N0
=
block_1d_id
/
M0
;
const
auto
N01_adapt
=
(
idx_N0
<
N0
-
N0
%
N01_
)
?
N01_
:
N0
%
N01_
;
index_t
idx_N00
=
idx_N0
/
N01_
;
index_t
idx_N01
=
idx_N0
%
N01_
;
index_t
idx_M0_N01_local
=
idx_M0
+
idx_N01
*
M0
;
/**
* idxN0
*
* |< mtx N >|
*
* |<---N01--->|
* - |-----------|-----------|-----------|-----|-----|-
* ^ | 0 ----------> 1 | | | |
* | | / | | | | M_0 MPerBlock
* | / | | | |
* |------/----------------|-----------|-----|-----|-
* | | | | | | |
* idxM0 | V | | | | | M_1 MPerBlock
* | 2 ----------> 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | blockid | | | |
* | | 5 | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* Example:
* assume:
* N0 = 5
* M0 = 4
* block_1d_id = 5
* N01 = 2
*
* idx_M0 = 1
* idx_N0 = 1
* N01_adapt = 2
* idx_N00 = 0
* idx_N01 = 1
* idx_M0_N01_local = 5
* output {2, 1}
*/
return
make_tuple
(
idx_M0_N01_local
/
N01_adapt
,
idx_M0_N01_local
%
N01_adapt
+
idx_N00
*
N01_
);
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t
M_
;
index_t
N_
;
index_t
N01_
;
};
// 2D slices of column-vectors in 3D space
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
0 → 100644
View file @
874a78f9
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment