gaoqiong / composable_kernel · Commits

Commit 4b0b327b
Authored Oct 02, 2023 by Umang Yadav

merge migx-jit-lib-hiprtc branch

Parent: ba251e4a
Changes: 33 · Showing 20 changed files with 317 additions and 196 deletions (+317 -196)
CMakeLists.txt                                                                                 +9   -9
include/ck/ck.hpp                                                                              +2   -1
include/ck/host_utility/device_prop.hpp                                                        +2   -0
include/ck/host_utility/kernel_launch.hpp                                                      +2   -1
include/ck/tensor_operation/gpu/device/device_base.hpp                                         +7   -4
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp                    +4   -1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp                              +8   -5
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp                                 +2   -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  +83  -79
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp            +37  -32
include/ck/tensor_operation/gpu/device/masking_specialization.hpp                              +2   -0
include/ck/tensor_operation/gpu/device/tensor_layout.hpp                                       +1   -1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                       +2   -1
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                    +35  -32
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                       +4   -2
include/ck/utility/amd_wave_read_first_lane.hpp                                                +13  -14
include/ck/utility/array.hpp                                                                   +1   -1
include/ck/utility/container_helper.hpp                                                        +2   -2
include/ck/utility/data_type.hpp                                                               +100 -9
include/ck/utility/debug.hpp                                                                   +1   -1
CMakeLists.txt  +9 -9

@@ -138,15 +138,15 @@ message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
 option(CK_BUILD_JIT_LIB "Only build the CK JIT Helper Library" OFF)
 if(NOT CK_BUILD_JIT_LIB)
     find_package(hip)
     # No assumption that HIP kernels are launched with uniform block size for backward compatibility
     # SWDEV-413293 and https://reviews.llvm.org/D155213
     math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
     message("hip_version_flat=${hip_VERSION_FLAT}")
     if(${hip_VERSION_FLAT} GREATER 500723302)
         message("Adding the fno-offload-uniform-block compiler flag")
         add_compile_options(-fno-offload-uniform-block)
     endif()
 endif()
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
 option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
include/ck/ck.hpp  +2 -1

@@ -4,11 +4,12 @@
 #pragma once

 #include "ck/config.h"

+#ifndef __HIPCC_RTC__
 #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
 #endif
+#endif

 #define CK_TIME_KERNEL 1
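With the runtime headers fenced off this way, ck.hpp can be handed to hipRTC directly: hipRTC predefines __HIPCC_RTC__, so the guarded includes drop out. A minimal host-side sketch of that flow, assuming the CK headers sit on the include path (the source string, file name, and include path below are placeholders, not part of this commit):

#include <hip/hiprtc.h>
#include <string>
#include <vector>

// Compile a CK-based kernel string at runtime; assumes 'source' includes
// "ck/ck.hpp" and that the CK include directory is reachable via -I.
// (Error checks omitted for brevity.)
std::vector<char> compile_with_hiprtc(const char* source)
{
    hiprtcProgram prog;
    hiprtcCreateProgram(&prog, source, "ck_jit_kernel.cpp", 0, nullptr, nullptr);

    // __HIPCC_RTC__ is predefined by hipRTC itself; only the language level
    // and the include path need to be supplied here.
    const char* opts[] = {"-std=c++17", "-I/path/to/composable_kernel/include"};
    hiprtcCompileProgram(prog, 2, opts);

    size_t code_size = 0;
    hiprtcGetCodeSize(prog, &code_size);
    std::vector<char> code(code_size);
    hiprtcGetCode(prog, code.data());
    hiprtcDestroyProgram(&prog);
    return code;
}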
include/ck/host_utility/device_prop.hpp  +2 -0

@@ -3,6 +3,7 @@
 #pragma once

+#ifndef __HIPCC_RTC__
 #include <string>
 #include <map>
 #include <hip/hip_runtime.h>

@@ -59,3 +60,4 @@ inline bool is_xdl_supported()
 }
 } // namespace ck
+#endif
include/ck/host_utility/kernel_launch.hpp  +2 -1

@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
+#ifndef __HIPCC_RTC__
 #include <hip/hip_runtime.h>

 #include "ck/ck.hpp"

@@ -142,3 +142,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
     return 0;
 #endif
 }
+#endif
include/ck/tensor_operation/gpu/device/device_base.hpp  +7 -4

@@ -2,16 +2,17 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <string>
 #include <sstream>

 #include "ck/stream_config.hpp"
+#endif

 namespace ck {
 namespace tensor_operation {
 namespace device {

+#ifndef __HIPCC_RTC__
 struct BaseArgument
 {
     BaseArgument() = default;

@@ -36,6 +37,7 @@ struct BaseInvoker
     virtual ~BaseInvoker() {}
 };
+#endif

 struct BaseOperator
 {

@@ -43,7 +45,9 @@ struct BaseOperator
     BaseOperator(const BaseOperator&) = default;
     BaseOperator& operator=(const BaseOperator&) = default;

+#ifndef __HIPCC_RTC__
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

     virtual std::string GetTypeString() const { return ""; }

     virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

@@ -56,7 +60,6 @@ struct BaseOperator
         return oss.str();
     };

     virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

     virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const

@@ -64,7 +67,7 @@ struct BaseOperator
         assert(p_arg);
         p_arg->p_workspace_ = p_workspace;
     }
+#endif

     virtual ~BaseOperator() {}
 };
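The pieces kept under the guard are exactly the host-facing, virtual half of the interface. For reference, a typical host-side use of that half looks roughly like the sketch below; the Run signature and the StreamConfig initializer follow CK's host API conventions elsewhere in the tree, so treat them as assumptions rather than something this commit defines:

#include <utility>

// Host-only flow, valid because BaseArgument/BaseInvoker exist whenever
// __HIPCC_RTC__ is not defined ('op' is any concrete CK device operation).
template <typename DeviceOp, typename... Args>
float run_if_supported(DeviceOp& op, Args&&... args)
{
    auto argument = op.MakeArgumentPointer(std::forward<Args>(args)...);
    auto invoker  = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(argument.get()))
        return -1.0f; // shape/layout not supported on this GPU

    // StreamConfig comes from "ck/stream_config.hpp", included above;
    // {stream, time_kernel} is the usual aggregate initializer.
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}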
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp  +4 -1

@@ -2,9 +2,10 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <vector>
+#endif

 #include "device_base.hpp"

@@ -28,6 +29,7 @@ template <typename ALayout,
           bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
+#ifndef __HIPCC_RTC__
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b0,

@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
                         CElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp  +8 -5

@@ -2,9 +2,11 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <array>
+#endif

+#include "ck/utility/array.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"

 namespace ck {

@@ -34,23 +36,24 @@ struct DeviceGemmMultipleD : public BaseOperator
 {
     static constexpr index_t NumDTensor = DsDataType::Size();

+#ifndef __HIPCC_RTC__
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        std::array<const void*, NumDTensor> p_ds,
+                        ck::Array<const void*, NumDTensor> p_ds,
                         void* p_e,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::Array<ck::index_t, NumDTensor> StrideDs,
                         ck::index_t StrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CDEElementwiseOperation cde_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };

 } // namespace device
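The std::array parameters become ck::Array so the same virtual interface remains visible to hipRTC, where the standard library is unavailable. Call sites change mechanically; a sketch of the new packing, assuming two D tensors (every name below is a plain parameter, not CK API):

#include "ck/utility/array.hpp"

// Sketch: how a caller now packs D-tensor pointers and strides. DeviceOp is
// any DeviceGemmMultipleD instance; ck::Array brace-initializes like the
// std::array it replaces.
template <typename DeviceOp, typename AOp, typename BOp, typename CDEOp>
auto make_multiple_d_argument(DeviceOp& op,
                              const void* p_a, const void* p_b,
                              const void* p_d0, const void* p_d1, void* p_e,
                              ck::index_t M, ck::index_t N, ck::index_t K,
                              ck::index_t sA, ck::index_t sB,
                              ck::index_t sD0, ck::index_t sD1, ck::index_t sE,
                              AOp a_op, BOp b_op, CDEOp cde_op)
{
    ck::Array<const void*, 2> p_ds{p_d0, p_d1};    // was std::array
    ck::Array<ck::index_t, 2> stride_ds{sD0, sD1}; // was std::array
    return op.MakeArgumentPointer(p_a, p_b, p_ds, p_e, M, N, K,
                                  sA, sB, stride_ds, sE, a_op, b_op, cde_op);
}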
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp  +2 -1

@@ -28,7 +28,7 @@ enum struct GemmSpecialization
     NKOPadding,
     MNKOPadding,
 };

+#ifndef __HIPCC_RTC__
 inline std::string getGemmSpecializationString(const GemmSpecialization& s)
 {
     switch(s)

@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
         default: return "Unrecognized specialization!";
     }
 }
+#endif

 } // namespace device
 } // namespace tensor_operation
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  +83 -79

@@ -3,8 +3,12 @@
 #pragma once

+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <sstream>
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif

 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"

@@ -15,8 +19,6 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -126,7 +128,6 @@ __global__ void
 // else
 //   AccElement = -INFINITY
 // Otherwise, result may be wrong.
 template <typename ALayout,
           typename BLayout, // B0Layout
           typename B1Layout,

@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                matrix_padder.PadN,
                MaskOutUpperTriangle>;

+#ifndef __HIPCC_RTC__
     // Argument
     struct Argument : public BaseArgument
     {

@@ -604,14 +606,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
+#endif

     static constexpr bool IsValidCompilationParameter()
     {
         // TODO: properly implement this check
         return true;
     }

     static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
     {
         // check vector load/store
         using Row = ck::tensor_layout::gemm::RowMajor;

@@ -699,7 +702,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return true;
     }

+#ifndef __HIPCC_RTC__
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!ck::is_xdl_supported())

@@ -757,7 +760,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                         BatchStrideB1,
                         BatchStrideC,
                         a_element_op,
                         b_element_op,
                         acc_element_op,
                         b1_element_op,
                         c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic

@@ -837,11 +839,11 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return str.str();
     }
+#endif

     template <class ADesc, class BDesc, class B1Desc, class CDesc>
     struct Descriptor
     {
         template <class AGridDescriptor>
         static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
         {
             const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);

@@ -851,14 +853,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             const auto AK0 = K / AK1;

             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }

         template <class BGridDescriptor>
         static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
         {
             const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);

@@ -868,14 +871,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             const auto BK0 = K / BK1;

             return transform_tensor_descriptor(
                 b_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }

         template <class B1GridDescriptor>
         static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
         {
             const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);

@@ -888,26 +892,24 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return transform_tensor_descriptor(
                 b1_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }

         template <class CGridDescriptor>
         static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
         {
             return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
         }

         using AGridDesc_AK0_M_AK1  = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
         using BGridDesc_BK0_N_BK1  = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
         using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
         using CGridDesc_M_N        = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

         // GridwiseGemm
         using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<

@@ -978,8 +980,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         CGridDesc_M_N c_grid_desc_m_n;
         C0MatrixMask c0_matrix_mask;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_descriptor_mblock_mperblock_nblock_nperblock;

         // element-wise op
         AElementwiseOperation a_element_op;
         BElementwiseOperation b_element_op;

@@ -1001,10 +1004,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
               b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
               b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
               c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
               block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
               c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                   GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
               has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                   a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
               c0_matrix_mask{c.GetLength(I1)},

@@ -1012,23 +1015,20 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
               b_element_op{b_element_op_},
               b1_element_op{b1_element_op_},
               c_element_op{c_element_op_},
               is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                    b_grid_desc_bk0_n_bk1,
                                                    b1_grid_desc_bk0_n_bk1,
                                                    c_grid_desc_m_n,
                                                    block_2_ctile_map) and
                        IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                    b_grid_desc_bk0_n_bk1.GetLength(I1),
                                    a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                        a_grid_desc_ak0_m_ak1.GetLength(I2),
                                    b1_grid_desc_bk0_n_bk1.GetLength(I1))}
         {
         }

         constexpr bool IsValid() const { return is_valid; }
     };

     template <class ADesc, class BDesc, class B1Desc, class CDesc>

@@ -1037,10 +1037,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                            BDesc b,
                            B1Desc b1,
                            CDesc c,
                            AElementwiseOperation a_element_op   = AElementwiseOperation{},
                            BElementwiseOperation b_element_op   = BElementwiseOperation{},
                            B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                            CElementwiseOperation c_element_op   = CElementwiseOperation{})
     {
         return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
             a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);

@@ -1054,47 +1054,51 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                         const ADataType* __restrict__ p_b1_grid,
                         CDataType* __restrict__ p_c_grid)
     {
+#ifndef __HIPCC_RTC__
         assert(desc.is_valid);
+#endif
         __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
         AccElementwiseOperation acc_element_op{scale};

         if(desc.has_main_k_block_loop)
         {
             Desc::GridwiseGemm::template Run<true>(
                 p_a_grid,
                 p_b_grid,
                 p_b1_grid,
                 p_c_grid,
                 p_shared_block,
                 desc.a_element_op,
                 desc.b_element_op,
                 acc_element_op,
                 desc.b1_element_op,
                 desc.c_element_op,
                 desc.a_grid_desc_ak0_m_ak1,
                 desc.b_grid_desc_bk0_n_bk1,
                 desc.b1_grid_desc_bk0_n_bk1,
                 desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                 desc.block_2_ctile_map,
                 desc.c0_matrix_mask);
         }
         else
         {
             Desc::GridwiseGemm::template Run<false>(
                 p_a_grid,
                 p_b_grid,
                 p_b1_grid,
                 p_c_grid,
                 p_shared_block,
                 desc.a_element_op,
                 desc.b_element_op,
                 acc_element_op,
                 desc.b1_element_op,
                 desc.c_element_op,
                 desc.a_grid_desc_ak0_m_ak1,
                 desc.b_grid_desc_bk0_n_bk1,
                 desc.b1_grid_desc_bk0_n_bk1,
                 desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                 desc.block_2_ctile_map,
                 desc.c0_matrix_mask);
         }
     }
 };
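The shape of the change in this file is the template for the whole commit: the runtime Argument/Invoker machinery stays host-only, while a constexpr Descriptor plus the device-side Run path survive under hipRTC. Reduced to its skeleton (illustrative only, not CK code):

// Skeleton of the host/RTC split used above (illustrative).
struct DeviceOpSketch
{
#ifndef __HIPCC_RTC__
    // Host half: runtime shapes, validation, kernel launch and timing.
    // Compiled out under hipRTC, so no <string>/<sstream>/typeid
    // dependencies leak into JIT-compiled translation units.
    struct Argument { /* runtime M, N, K, strides */ };
    struct Invoker  { /* launch_and_time_kernel(...) */ };
#endif

    // RTC half: everything the generated kernel needs is carried by a
    // compile-time descriptor, so validity reduces to a constexpr check.
    template <class ADesc, class BDesc>
    struct Descriptor
    {
        static constexpr bool is_valid = true; // stands in for CheckValidity(...)
        constexpr bool IsValid() const { return is_valid; }
    };
};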
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  +37 -32

@@ -2,20 +2,22 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <sstream>
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif

-#include "ck/utility/common_header.hpp"
+#include "ck/utility/array.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/utility/common_header.hpp"

 namespace ck {

@@ -223,9 +225,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }

-    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                                         const std::array<index_t, NumDTensor>& NRaws,
-                                         const std::array<index_t, NumDTensor>& DsStride)
+    static auto MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
+                                         const ck::Array<index_t, NumDTensor>& NRaws,
+                                         const ck::Array<index_t, NumDTensor>& DsStride)
     {
         return generate_tuple(
             [&](auto i) {

@@ -304,20 +306,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     // block-to-e-tile map
     using Block2ETileMap =
         remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

+#ifndef __HIPCC_RTC__
     // Argument
     struct Argument : public BaseArgument
     {
         Argument(const void* p_a_grid,
                  const void* p_b_grid,
-                 std::array<const void*, NumDTensor> p_ds_grid,
+                 ck::Array<const void*, NumDTensor> p_ds_grid,
                  void* p_e_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
                  index_t StrideA,
                  index_t StrideB,
-                 std::array<index_t, NumDTensor> StrideDs,
+                 ck::Array<index_t, NumDTensor> StrideDs,
                  index_t StrideE,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,

@@ -416,7 +418,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         index_t NRaw_;
         index_t KRaw_;
     };

     // Invoker
     struct Invoker : public BaseInvoker
     {

@@ -492,7 +493,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
+#endif

     static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
     {
         // check vector load/store

@@ -574,10 +575,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         }
         return true;
     }

+#ifndef __HIPCC_RTC__
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

@@ -595,17 +598,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     {
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

     static auto MakeArgument(const void* p_a,
                              const void* p_b,
-                             std::array<const void*, NumDTensor> p_ds,
+                             ck::Array<const void*, NumDTensor> p_ds,
                              void* p_e,
                              index_t MRaw,
                              index_t NRaw,
                              index_t KRaw,
                              index_t StrideA,
                              index_t StrideB,
-                             std::array<index_t, NumDTensor> StrideDs,
+                             ck::Array<index_t, NumDTensor> StrideDs,
                              index_t StrideE,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,

@@ -633,14 +635,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        std::array<const void*, NumDTensor> p_ds,
+                        ck::Array<const void*, NumDTensor> p_ds,
                         void* p_e,
                         index_t MRaw,
                         index_t NRaw,
                         index_t KRaw,
                         index_t StrideA,
                         index_t StrideB,
-                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::Array<ck::index_t, NumDTensor> StrideDs,
                         index_t StrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,

@@ -673,11 +675,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     {
         auto str = std::stringstream();

-        std::map<LoopScheduler, std::string> LoopSchedToString{
-            {LoopScheduler::Default, "Default"},
-            {LoopScheduler::Interwave, "Interwave"}};
+        std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
+                                                               {LoopScheduler::Interwave, "Interwave"}};

         std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                        {PipelineVersion::v2, "v2"}};

         // clang-format off
         str << "DeviceGemmMultipleD_Xdl_CShuffle"

@@ -706,6 +710,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         return str.str();
     }
+#endif

     template <class ADesc, class BDesc, class DsDesc, class EDesc>
     struct Descriptor

@@ -722,10 +727,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         using BGridDesc_BK0_N_BK1 =
             remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
                 DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
         using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
             decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                 ds_tuple()))>;
         using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
             GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                 DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
         using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
             DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;

@@ -735,7 +741,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
         Block2ETileMap block_2_etile_map;

         // element-wise op
         AElementwiseOperation a_element_op;
         BElementwiseOperation b_element_op;

@@ -786,10 +792,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         {
         }

         constexpr bool IsValid() const { return is_valid; }
     };

     template <class ADesc, class BDesc, class DsDesc, class EDesc>

@@ -814,7 +817,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                         EDataType* __restrict__ p_e_grid)
     {
         __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+#ifndef __HIPCC_RTC__
         assert(desc.is_valid);
+#endif
         if(desc.has_main_k_block_loop)
         {
             GridwiseGemm::template Run<true>(p_a_grid,
include/ck/tensor_operation/gpu/device/masking_specialization.hpp  +2 -0

@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
     MaskOutUpperTriangle
 };

+#ifndef __HIPCC_RTC__
 inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
 {
     switch(s)

@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
         default: return "Unrecognized specialization!";
     }
 }
+#endif

 struct MaskDisabledPredicate
 {
include/ck/tensor_operation/gpu/device/tensor_layout.hpp  +1 -1

@@ -406,7 +406,7 @@ struct G_NDHW : public BaseTensorLayout
 template <
     typename Layout,
-    typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
+    typename ck::enable_if<ck::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
 std::ostream& operator<<(std::ostream& os, const Layout&)
 {
     os << Layout::name;
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +2 -1

@@ -288,6 +288,7 @@ struct FastGelu
     template <typename Y, typename X>
     __device__ void operator()(Y& y, const X& x) const;

+#ifndef __HIPCC_RTC__
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {

@@ -297,7 +298,7 @@ struct FastGelu
         y = x * cdf;
     }
+#endif

     // device code, use lower precision "__expf" and "rcp"
     template <>
     __device__ void operator()<float, float>(float& y, const float& x) const
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  +35 -32

@@ -5,10 +5,13 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"

+#ifndef __HIPCC_RTC__
 #include <limits>
 #include <stdlib.h>
+#endif

 namespace ck {

@@ -86,16 +89,16 @@ struct BlockToCTileMap_M00_N0_M01
         const auto M00 = math::integer_divide_ceil(M0, M01);

         const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
             ck::make_tuple(make_insert_transform(1),
                            make_unmerge_transform(ck::make_tuple(M00, M01)),
                            make_pass_through_transform(ck::make_tuple(N0))),
             ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
             ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

         const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
             ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N0, M01))),
             ck::make_tuple(Sequence<0, 1, 2, 3>{}),
             ck::make_tuple(Sequence<0>{}));

         const auto cblockid_to_m0_n0_block_cluster_adaptor =
             chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,

@@ -231,8 +234,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
          * output {1, 2}
          */
-        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                          idx_N0_M01_local / M01_adapt);
+        return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
     }

     template <typename CTileIdx, typename CTileDim>

@@ -307,9 +310,9 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
         index_t idx_M01          = idx_M0 % M01_;
         index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

-        return make_tuple(idx_ksplit,
-                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                          idx_N0_M01_local / M01_adapt);
+        return ck::make_tuple(idx_ksplit,
+                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
     }

     template <typename CTileIdx, typename CTileDim>

@@ -406,17 +409,17 @@ struct BlockToCTileMap_M00_N00_M01_N01
         const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
             make_single_stage_tensor_adaptor(
                 ck::make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
                                make_unmerge_transform(ck::make_tuple(M00, M01)),
                                make_unmerge_transform(ck::make_tuple(N00, N01))),
                 ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                 ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

         const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
             make_single_stage_tensor_adaptor(
                 ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N00, M01, N01))),
                 ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                 ck::make_tuple(Sequence<0>{}));

         const auto cblockid_to_m0_n0_block_cluster_adaptor =
             chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,

@@ -525,17 +528,17 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
         const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
             make_single_stage_tensor_adaptor(
                 ck::make_tuple(make_pass_through_transform(KSplit),
                                make_unmerge_transform(ck::make_tuple(M00, M01)),
                                make_unmerge_transform(ck::make_tuple(N00, N01))),
                 ck::make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                 ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

         const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
             make_single_stage_tensor_adaptor(
                 ck::make_tuple(make_merge_transform(ck::make_tuple(KSplit, M00, N00, M01, N01))),
                 ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                 ck::make_tuple(Sequence<0>{}));

         const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
             chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,

@@ -652,13 +655,13 @@ struct BlockToCTileMap_3DGrid_KSplit
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-        return std::make_tuple(N0, M0, k_split);
+        return ck::make_tuple(N0, M0, k_split);
     }

     template <typename TopIdx>
     __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
     {
-        return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
+        return ck::make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
     }

     template <typename CTileIdx, typename CTileDim>

@@ -776,7 +779,7 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t dp_for_sk_iters = k_iters_per_tile.get();

-        uint32_t best_sk_score = std::numeric_limits<int>::max();
+        uint32_t best_sk_score = ck::NumericLimits<int32_t>::Max();
         // we need to find the smallest sk iters
         for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
             tentative_sk_blocks++)
         {
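A plausible reading of why every make_tuple in this header gains an explicit ck:: prefix: when these headers are compiled where another make_tuple overload set is visible, argument-dependent lookup can make the unqualified call ambiguous, and qualification side-steps lookup entirely. A standalone illustration of that hazard (demo code, not CK's):

#include <tuple>

namespace ck_demo {
template <typename... Ts>
struct Tuple {};

template <typename... Ts>
constexpr Tuple<Ts...> make_tuple(Ts...) { return {}; }

// Unqualified call: ordinary lookup finds ck_demo::make_tuple, and because
// the argument is a std:: type, ADL adds std::make_tuple as a candidate too.
inline auto lookup_example(std::tuple<int> t)
{
    // return make_tuple(t);       // can be ambiguous: ck_demo::make_tuple
    //                             // vs std::make_tuple via ADL
    return ck_demo::make_tuple(t); // qualified: always unambiguous
}
} // namespace ck_demo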
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  +4 -2

@@ -3,11 +3,11 @@
 #pragma once

-#include <iostream>
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

+#ifndef __HIPCC_RTC__
 #include <iostream>
+#endif

 namespace ck {

@@ -39,7 +39,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
     }
     else
     {
+#ifndef __HIPCC_RTC__
         std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
+#endif
     }
 }
include/ck/utility/amd_wave_read_first_lane.hpp
View file @
4b0b327b
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -7,10 +7,12 @@
 #include "ck/utility/functional2.hpp"
 #include "ck/utility/math.hpp"

+#ifndef __HIPCC_RTC__
 #include <array>
 #include <cstddef>
 #include <cstdint>
 #include <type_traits>
+#endif

 namespace ck {
 namespace detail {

@@ -37,7 +39,7 @@ struct get_carrier<3>
 {
     using value_type = uint32_t;

-    std::array<std::byte, 3> bytes;
+    ck::byte bytes[3];

     static_assert(sizeof(bytes) <= sizeof(value_type));

     // replacement of host std::copy_n()

@@ -59,24 +61,21 @@ struct get_carrier<3>
     }

     // method to trigger template substitution failure
     __device__ carrier(const carrier& other) noexcept
     {
-        copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+        copy_n(&other.bytes[0], 3, &bytes[0]);
     }

     public:
     __device__ carrier& operator=(value_type value) noexcept
     {
-        copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+        copy_n(reinterpret_cast<const ck::byte*>(&value), 3, &bytes[0]);
         return *this;
     }

     __device__ operator value_type() const noexcept
     {
-        std::byte result[sizeof(value_type)];
-        copy_n(bytes.begin(), bytes.size(), result);
+        ck::byte result[sizeof(value_type)];
+        copy_n(&bytes[0], 3, result);
         return *reinterpret_cast<const value_type*>(result);
     }

@@ -100,17 +99,17 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
     return __builtin_amdgcn_readfirstlane(value);
 }

 template <typename Object,
-          typename = std::enable_if_t<std::is_class_v<Object> &&
-                                      std::is_trivially_copyable_v<Object>>>
+          typename = ck::enable_if_t<ck::is_class<Object>::value &&
+                                     ck::is_trivially_copyable<Object>::value>>
 __device__ auto amd_wave_read_first_lane(const Object& obj)
 {
     using Size = unsigned;

     constexpr Size SgprSize   = 4;
     constexpr Size ObjectSize = sizeof(Object);

-    auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
-    alignas(Object) std::byte to_obj[ObjectSize];
+    auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
+    alignas(Object) ck::byte to_obj[ObjectSize];

     constexpr Size RemainedSize             = ObjectSize % SgprSize;
     constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;

...
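The generic overload above broadcasts an arbitrary trivially-copyable object by slicing it into 4-byte chunks, one per SGPR, with the 1-3 byte tail handled by the get_carrier byte carriers. A host-side analogue of that chunked copy, for illustration only (broadcast_like is a hypothetical name; on the GPU each chunk goes through __builtin_amdgcn_readfirstlane instead of memcpy):

#include <cstdint>
#include <cstring>

template <typename Object>
Object broadcast_like(const Object& obj)
{
    constexpr unsigned SgprSize   = 4; // one SGPR carries 4 bytes
    constexpr unsigned ObjectSize = sizeof(Object);
    constexpr unsigned Remainder  = ObjectSize % SgprSize;
    constexpr unsigned Boundary   = ObjectSize - Remainder;

    const auto* from = reinterpret_cast<const unsigned char*>(&obj);
    alignas(Object) unsigned char to[ObjectSize];

    // Complete 4-byte chunks: each one becomes a single readfirstlane on-device.
    for(unsigned offset = 0; offset != Boundary; offset += SgprSize)
    {
        std::uint32_t chunk;
        std::memcpy(&chunk, from + offset, SgprSize);
        // device code: chunk = __builtin_amdgcn_readfirstlane(chunk);
        std::memcpy(to + offset, &chunk, SgprSize);
    }

    // 1-3 byte tail: on-device, a get_carrier<Remainder>::carrier makes these
    // bytes addressable as one 8/16/32-bit value for readfirstlane.
    if(Remainder != 0)
    {
        std::memcpy(to + Boundary, from + Boundary, Remainder);
    }

    Object result;
    std::memcpy(&result, to, ObjectSize);
    return result;
}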
include/ck/utility/array.hpp
@@ -52,7 +52,7 @@ template <typename X, typename... Xs>
 __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
 {
     using data_type = remove_cvref_t<X>;
-    return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
+    return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
 }

 // make empty array

...
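make_array now forwards through ck::forward because std::forward lives in <utility>, a host header that hipRTC cannot include. Perfect forwarding needs no library support, so it can be re-declared freestanding; a minimal sketch of what ck::forward presumably looks like (names assumed; CK's real definition additionally carries __host__ __device__ qualifiers and may differ in detail):

namespace ck_sketch {

template <typename T> struct remove_reference      { using type = T; };
template <typename T> struct remove_reference<T&>  { using type = T; };
template <typename T> struct remove_reference<T&&> { using type = T; };

// forward an lvalue as either an lvalue or an rvalue, depending on the deduced T
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t) noexcept
{
    return static_cast<T&&>(t);
}

// forward an rvalue as an rvalue (std::forward also static_asserts that
// T is not an lvalue reference here; omitted for brevity)
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t) noexcept
{
    return static_cast<T&&>(t);
}

} // namespace ck_sketch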
include/ck/utility/container_helper.hpp
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
 __host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
 {
     return unpack2(
-        [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
+        [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
 }

 template <typename... X, typename... Y>
 __host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
 {
     return unpack2(
-        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
+        [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
 }

 template <typename Container>

...
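container_concat follows an unpack-and-rebuild idiom: both containers are flattened into one parameter pack, which the lambda perfectly forwards into a fresh container; the forwarding is exactly the role ck::forward takes over from std::forward above. A host-only analogue using std::tuple_cat and std::apply (illustrative, compile as C++20; CK's unpack2/make_array are device-side):

#include <array>
#include <tuple>
#include <utility>

template <typename T, std::size_t NX, std::size_t NY>
constexpr auto concat_arrays(const std::array<T, NX>& ax, const std::array<T, NY>& ay)
{
    return std::apply(
        [](auto&&... zs) { return std::array<T, NX + NY>{std::forward<decltype(zs)>(zs)...}; },
        std::tuple_cat(ax, ay)); // std::tuple_cat accepts std::array arguments
}

static_assert(concat_arrays(std::array<int, 2>{0, 1}, std::array<int, 3>{2, 3, 4}) ==
              std::array<int, 5>{0, 1, 2, 3, 4});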
include/ck/utility/data_type.hpp
@@ -5,7 +5,22 @@
 #include "ck/utility/statically_indexed_array.hpp"

+#ifdef __HIPCC_RTC__
+/// Definitions from <cstdint>, <cmath> conflict with
+/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
+using int8_t   = signed char;
+using uint8_t  = unsigned char;
+using int16_t  = signed short;
+using uint16_t = unsigned short;
+using float_t  = float;
+#endif // __HIPCC_RTC__

 namespace ck {

+#ifdef __HIPCC_RTC__
+using byte = unsigned char;
+#else
+using std::byte;
+#endif

 using bhalf_t = ushort;
 using half_t  = _Float16;

@@ -961,20 +976,96 @@ using f8x32_t = typename vector_type<f8_t, 32>::type;
 using f8x64_t = typename vector_type<f8_t, 64>::type;

 template <typename T>
-struct NumericLimits
-{
-    __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
-    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
-    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
-    __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits<T>::quiet_NaN(); }
-    __host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
-};
+struct NumericLimits;
+
+template <>
+struct NumericLimits<int32_t>
+{
+    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
+    __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
+    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
+    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<int16_t>
+{
+    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
+    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
+    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
+    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<int8_t>
+{
+    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
+    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
+    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
+    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint32_t>
+{
+    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
+    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint16_t>
+{
+    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
+    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<float>
+{
+    static constexpr unsigned int binary_min    = 0x00800000;
+    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
+    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
+    static constexpr unsigned int binary_qnan   = 0xFFC00001;
+    static constexpr unsigned int binary_inf    = 0x7F800000;
+
+    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
+    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
+    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
+    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
+    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
+};

 template <>

...
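The float specialization hard-codes IEEE-754 binary32 bit patterns so the limits work without <limits> under hipRTC. Note that binary_inf appeared in the diff as 0x7F8000000 (nine hex digits), which does not fit in unsigned int and wraps to 0xF8000000; the value is corrected to 0x7F800000 above. A small host-side check of the patterns (assumes C++20 for constexpr std::bit_cast; illustrative, not part of CK):

#include <bit>
#include <cstdio>
#include <limits>

int main()
{
    // exponent all ones, mantissa zero -> +infinity
    static_assert(std::bit_cast<float>(0x7F800000u) == std::numeric_limits<float>::infinity());
    // smallest positive normal value (FLT_MIN)
    static_assert(std::bit_cast<float>(0x00800000u) == std::numeric_limits<float>::min());
    // largest finite value (FLT_MAX)
    static_assert(std::bit_cast<float>(0x7F7FFFFFu) == std::numeric_limits<float>::max());
    // most negative finite value (-FLT_MAX)
    static_assert(std::bit_cast<float>(0xFF7FFFFFu) == std::numeric_limits<float>::lowest());
    // exponent all ones, mantissa nonzero -> NaN (NaN never compares equal to itself)
    static_assert(std::bit_cast<float>(0xFFC00001u) != std::bit_cast<float>(0xFFC00001u));

    std::puts("NumericLimits<float> bit patterns verified");
}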
include/ck/utility/debug.hpp
@@ -3,7 +3,7 @@
 #ifndef UTILITY_DEBUG_HPP
 #define UTILITY_DEBUG_HPP

+#include "type.hpp"

 namespace ck {
 namespace debug {

...