gaoqiong / MIGraphX · Commits

Commit d46c7224 authored May 24, 2023 by Alan Turner

Update to use new API

parent 15bf2de8

Showing 6 changed files with 44 additions and 354 deletions (+44 -354)
src/targets/gpu/CMakeLists.txt                                          +2   -5
src/targets/gpu/compile_hip_code_object.cpp                             +13  -0
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp        +2   -0
src/targets/gpu/jit/ck_gemm.cpp                                         +27  -87
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp            +0   -1
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp   +0   -261
src/targets/gpu/CMakeLists.txt
@@ -261,15 +261,12 @@ else()
    message(STATUS "MIOpen does not have find mode api")
endif()
#find_package(composable_kernel REQUIRED PATHS /code/composable_kernel)
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
#target_link_libraries(migraphx_gpu PRIVATE composable_kernel::device_operations)
find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::device_operations)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)
add_subdirectory(driver)
src/targets/gpu/compile_hip_code_object.cpp
@@ -167,6 +167,19 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
            auto path = fs::path{"migraphx"} / "kernels" / name;
            return src_file{path, c};
        });

    if(not options.embedded_headers.empty())
    {
        std::transform(options.embedded_headers.begin(),
                       options.embedded_headers.end(),
                       std::back_inserter(srcs),
                       [](auto&& p) {
                           auto&& name = p.first;
                           auto&& c    = p.second;
                           auto path   = fs::path{"migraphx"} / "kernels" / name;
                           return src_file{path, c};
                       });
    }

    srcs.push_back(src_file{fs::path{"main.cpp"},
                            std::make_pair(content.data(), content.data() + content.size())});

    auto args_hpp =
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
@@ -42,6 +42,8 @@ struct hip_compile_options
    std::string kernel_name           = "kernel";
    std::string params                = "";
    std::vector<shape> virtual_inputs = {};
    std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
/**
* @brief Set the launch parameters but allow v to override the values
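The new embedded_headers member maps a header file name to a non-owning [begin, end) character range, and the compile step above writes each entry out as a virtual source under migraphx/kernels/. The sketch below is illustrative only, not MIGraphX code: it uses std::filesystem::path as a stand-in for MIGraphX's fs alias, and the header name and contents are made up.

```cpp
// Minimal sketch, assuming MIGraphX's fs alias behaves like std::filesystem.
// Shows the data shape of hip_compile_options::embedded_headers and how each
// entry maps to a "migraphx/kernels/<name>" virtual path.
#include <filesystem>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

namespace fs = std::filesystem;

int main()
{
    // Hypothetical header text; in this commit the real text comes from the CK jit library.
    static const std::string gemm_hpp = "// generated ck gemm header\n";

    // Same type as the embedded_headers member added in the hunk above.
    std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
    embedded_headers["ck_gemm_generated.hpp"] = // name is illustrative only
        std::make_pair(gemm_hpp.data(), gemm_hpp.data() + gemm_hpp.size());

    // Each entry becomes a virtual source path, mirroring the std::transform
    // in compile_hip_code_object.cpp.
    for(const auto& [name, range] : embedded_headers)
    {
        auto path = fs::path{"migraphx"} / "kernels" / name;
        std::cout << path.string() << " (" << (range.second - range.first) << " bytes)\n";
    }
}
```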
src/targets/gpu/jit/ck_gemm.cpp
@@ -38,18 +38,8 @@
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/include/device_gemm_multiple_d.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/tensor_operation_instance/solution_instances/gemm_multiple_d_xdlop_cshuffle.hpp"
#include <iostream>
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -68,7 +58,7 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp>
#include <migraphx/kernels/${include}>
namespace migraphx {
@@ -89,64 +79,7 @@ __global__ void ${kernel}(${params})
)__migraphx__";

static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

struct instance
{
    std::vector<std::string> params;

    static const std::size_t block_size_index = 15;

    std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }

    std::size_t get_block_size() const { return int_at(block_size_index); }

    std::size_t get_pb(std::size_t i) const
    {
        assert(i < 4);
        return int_at(block_size_index + 1 + i);
    }

    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
    {
        std::array<std::size_t, 3> result{};
        for(auto i : range(config.size()))
        {
            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
        }
        return result;
    }

    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
    {
        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
    }

    void set_ds_layout(const std::string& s)
    {
        assert(params[2] == "ck::Tuple<>");
        params[2] = s;
    }

    void set_ds_type(const std::string& s)
    {
        assert(params[8] == "ck::Tuple<>");
        params[8] = s;
    }

    void set_ds_op(const std::string& s)
    {
        assert(params[12] == "ck_passthrough");
        params[12] = s;
    }

    void set_gemm(const std::string& s)
    {
        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
        params[13] = s;
    }

    std::string str() const { return join_strings(params, ","); }
};

static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
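The removed instance helper above parsed per-block tile sizes out of a CK template-parameter string and derived padding and grid size from them. The standalone sketch below only reproduces that arithmetic; the problem and tile sizes are made up and do not come from this commit.

```cpp
// Minimal sketch, not MIGraphX code: the padding / grid-size arithmetic of the
// removed `instance` helper, with illustrative sizes.
#include <array>
#include <cstddef>
#include <iostream>

// Same formula as the removed int_div_ceil.
static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main()
{
    // Hypothetical problem size (M, N, K) and per-block tile sizes.
    const std::array<std::size_t, 3> config{1000, 512, 96};
    const std::array<std::size_t, 3> per_block{128, 128, 32};

    // get_pad: how much each dimension must grow to become a multiple of its tile.
    std::array<std::size_t, 3> pad{};
    for(std::size_t i = 0; i < config.size(); ++i)
        pad[i] = int_div_ceil(config[i], per_block[i]) * per_block[i] - config[i];

    // get_grid_size: one workgroup per (M, N) output tile.
    const std::size_t grid_size =
        int_div_ceil(config[0], per_block[0]) * int_div_ceil(config[1], per_block[1]);

    std::cout << "pad = {" << pad[0] << ", " << pad[1] << ", " << pad[2] << "}\n"; // {24, 0, 0}
    std::cout << "grid_size = " << grid_size << "\n";                              // 8 * 4 = 32
}
```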
@@ -304,18 +237,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
        auto n = c_shape.lens().back();
        auto k = a_shape.lens().back();

        const auto numDTensors = inputs.size() - 3;

        const bool transA     = transposed_matrix(a_shape);
        const bool transB     = transposed_matrix(b_shape);
        const bool transCDE   = transposed_matrix(c_shape);
        const bool transE     = transposed_matrix(c_shape);
        const auto a_type     = get_type(a_shape);
        const auto b_type     = get_type(b_shape);
        const auto cde_type   = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type); // get_type(c_shape);
        const auto cde_layout = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
        const auto e_type     = get_type(c_shape);

        std::vector<bool> ds_layout;
        std::transform(inputs.begin() + 2,
                       inputs.end() - 1,
                       std::back_inserter(ds_layout),
                       [](const auto& i) { return transposed_matrix(i); });

        std::vector<std::string> ds_type;
        std::transform(inputs.begin() + 2,
                       inputs.end() - 1,
                       std::back_inserter(ds_type),
                       [](const auto& i) { return get_type(i); });

        std::string ck_passthrough = "ck_passthrough";
        //"ck::tensor_operation::element_wise::PassThrough";
        std::string ck_passthrough = "ck_passthrough";
        std::string cde_op         = ck_passthrough;
        assert(inputs.size() < 4 or v.contains("post"));
        if(v.contains("post"))
@@ -324,27 +257,33 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
}
        auto problem = ck::tensor_operation::device::instance::Problem{static_cast<ck::index_t>(m),
        ck::tensor_operation::device::device_gemm_multiple_d::Problem{static_cast<ck::index_t>(m),
                                                                      static_cast<ck::index_t>(n),
                                                                      static_cast<ck::index_t>(k),
                                                                      static_cast<ck::index_t>(numDTensors),
                                                                      transA,
                                                                      transB,
                                                                      transCDE,
                                                                      transE,
                                                                      ds_layout,
                                                                      a_type,
                                                                      b_type,
                                                                      cde_type,
                                                                      e_type,
                                                                      ds_type,
                                                                      ck_passthrough,
                                                                      ck_passthrough,
                                                                      cde_op,
                                                                      cde_layout};
        const auto solutions = problem.GetSolutions();
                                                                      cde_op};

        const auto include_header = problem.GetIncludeHeader();
        const auto ck_headers     = problem.GetHeaders();
        const auto solutions      = problem.GetSolutions("gfx90a");

        const auto solution = solutions.at(tuning_value);

        const auto template_str     = solution.GetStr();
        const auto blocks_per_batch = solution.GetGridSize();
        const auto block_size       = solution.GetBlockSize();
        const auto template_str     = solution.template_str;
        const auto blocks_per_batch = solution.grid_size;
        const auto block_size       = solution.block_size;

        hip_compile_options options;
        options.embedded_headers = ck_headers;
        auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs = inputs;
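As read from the hunk above, the new path asks the CK jit library for candidate solutions, picks one by tuning_value, and derives the HIP launch size from the solution's grid_size and block_size fields, multiplying by batch_count when the batch cannot be folded into the GEMM. The sketch below uses stand-in types only; the solution_record struct and all numeric values are hypothetical and not part of MIGraphX or composable_kernel.

```cpp
// Minimal sketch, assuming a solution record with the fields used in the hunk
// above (template_str, grid_size, block_size). Shows how a tuned solution's
// sizes turn into launch parameters.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for a CK jit-library solution record.
struct solution_record
{
    std::string template_str; // full template instantiation text
    std::size_t grid_size;    // blocks per batch
    std::size_t block_size;   // threads per block
};

int main()
{
    // Pretend the library returned a few candidates for this problem size.
    const std::vector<solution_record> solutions = {
        {"DeviceGemmMultipleD_Xdl_CShuffle<...A...>", 32, 256},
        {"DeviceGemmMultipleD_Xdl_CShuffle<...B...>", 64, 128},
    };

    const std::size_t tuning_value   = 1;     // illustrative; normally picked by tuning
    const std::size_t batch_count    = 4;     // illustrative
    const bool        can_fold_batch = false;

    const auto& solution        = solutions.at(tuning_value);
    const auto blocks_per_batch = solution.grid_size;
    const auto block_size       = solution.block_size;

    // Same arithmetic as the hunk: one grid per batch unless the batch folds away.
    const auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;

    std::cout << "global threads = " << grid_size * block_size
              << ", block size = " << block_size << "\n"; // 4 * 64 * 128 = 32768, 128
}
```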
@@ -365,12 +304,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
        auto src = interpolate_string(ck_gemm_kernel,
                                      {{"solution", template_str},
                                       {"include", include_header},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"kernel", options.kernel_name}});
        std::cout << src << std::endl;

        return compile_hip_code_object(src, options);
    }
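The ck_gemm_kernel raw string used above is a template with ${...} placeholders (solution, include, params, args, blocks_per_batch, preamble, kernel) that interpolate_string fills before the source is compiled. The toy interpolator below is a stand-in, not MIGraphX's interpolate_string, applied to a cut-down template; the substituted values are made up.

```cpp
// Minimal sketch, assuming only that interpolation means replacing ${key}
// tokens with strings. Not MIGraphX code.
#include <iostream>
#include <string>
#include <unordered_map>

static std::string interpolate(std::string text,
                               const std::unordered_map<std::string, std::string>& vars)
{
    for(const auto& [key, value] : vars)
    {
        const std::string token = "${" + key + "}";
        for(std::size_t pos = text.find(token); pos != std::string::npos;
            pos = text.find(token, pos + value.size()))
            text.replace(pos, token.size(), value);
    }
    return text;
}

int main()
{
    // Cut-down template; the commit's real template also carries headers and a preamble.
    const std::string tmpl = "#include <migraphx/kernels/${include}>\n"
                             "__global__ void ${kernel}(${params}) { /* uses ${solution} */ }\n";

    const auto src = interpolate(tmpl,
                                 {{"include", "ck_gemm.hpp"},
                                  {"kernel", "ck_gemm_kernel_0"}, // illustrative name
                                  {"params", "void * private_p0, void * private_p1"},
                                  {"solution", "DeviceGemmMultipleD_Xdl_CShuffle<...>"}});

    std::cout << src;
}
```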
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
@@ -29,7 +29,6 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_gemm_includes.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace migraphx {
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
deleted 100644 → 0
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/device_gemm.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp>
#include <ck/tensor_operation/gpu/device/matrix_padder.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp>
namespace migraphx {

template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};
    static constexpr auto I2 = ck::Number<2>{};
    static constexpr auto I3 = ck::Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, ck::index_t M01 = 8)
        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ __device__ constexpr ck::index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const ck::index_t grid_size = M0 * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        ck::index_t idx_N0 = block_1d_id % N0;
        ck::index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        ck::index_t idx_M00          = idx_M0 / M01_;
        ck::index_t idx_M01          = idx_M0 % M01_;
        ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    ck::index_t M01_;
    CGridDesc_M_N c_grid_desc_m_n_;
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::tensor_operation::device::GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          ck::index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_BK1,
          ck::index_t BBlockLdsExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
          ck::LoopScheduler LoopSched     = ck::make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct CK_DeviceGemmMultipleD
{
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};
    // static constexpr auto I2 = ck::Number<2>{};
    // static constexpr auto I3 = ck::Number<3>{};
    // static constexpr auto I4 = ck::Number<4>{};
    // static constexpr auto I5 = ck::Number<5>{};
    // static constexpr auto I6 = ck::Number<6>{};
    // static constexpr auto I7 = ck::Number<7>{};

    ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
        matrix_padder{MPerBlock, NPerBlock, KPerBlock};

    // GridwiseGemm
    using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        ck::InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        PipelineVer>;

    // return block_id to E matrix tile idx (m0, n0) mapping
    template <class EGridDesc_M_N>
    __device__ static constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
            e_grid_desc_m_n_);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename AGridDesc_M_K,
              typename BGridDesc_N_K,
              typename DsGridDesc_M_N,
              typename EGridDesc_M_N,
              typename Block2ETileMap>
    static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                        const BGridDesc_N_K& b_grid_desc_n_k,
                                        const DsGridDesc_M_N& ds_grid_desc_m_n,
                                        const EGridDesc_M_N& e_grid_desc_m_n,
                                        const Block2ETileMap& block_2_etile_map)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        // check consistency of desc
        MIGRAPHX_CHECK(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1));

        // check tile size
        MIGRAPHX_CHECK(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0);

        // check block-to-E-tile
        MIGRAPHX_CHECK(block_2_etile_map.CheckValidity(e_grid_desc_m_n));

        return GridwiseGemm::CheckValidity(
            a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map);
    }

    AElementwiseOperation a_element_op{};
    BElementwiseOperation b_element_op{};
    CDEElementwiseOperation cde_element_op{};
};

} // namespace migraphx
#endif
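The deleted BlockToCTileMap_M00_N0_M01Adapt maps a 1-D block id to an (m0, n0) output-tile index, regrouping consecutive block ids so that nearby blocks work on nearby tiles. The host-only sketch below reimplements the CalculateBottomIndex arithmetic in plain C++ for illustration; the function name and the problem sizes are made up, only the formulas mirror the deleted code.

```cpp
// Minimal sketch, not composable_kernel code: host-only version of the deleted
// CalculateBottomIndex arithmetic, mapping a 1-D block id to an (m0, n0) tile.
#include <cstdint>
#include <iostream>
#include <utility>

static int64_t div_ceil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Mirrors BlockToCTileMap_M00_N0_M01Adapt::CalculateBottomIndex for one block id.
static std::pair<int64_t, int64_t> bottom_index(int64_t block_1d_id,
                                                int64_t M, int64_t N,
                                                int64_t m_per_block, int64_t n_per_block,
                                                int64_t M01 = 8)
{
    const int64_t M0 = div_ceil(M, m_per_block);
    const int64_t N0 = div_ceil(N, n_per_block);

    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

    const int64_t idx_N0 = block_1d_id % N0;
    const int64_t idx_M0 = block_1d_id / N0;

    const int64_t M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

    const int64_t idx_M00          = idx_M0 / M01;
    const int64_t idx_M01          = idx_M0 % M01;
    const int64_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01,
            idx_N0_M01_local / M01_adapt};
}

int main()
{
    // A 2048 x 1024 output tiled by 128 x 128 gives a 16 x 8 grid of tiles.
    for(int64_t id : {0, 1, 7, 8, 127})
    {
        const auto [m0, n0] = bottom_index(id, 2048, 1024, 128, 128);
        std::cout << "block " << id << " -> tile (" << m0 << ", " << n0 << ")\n";
    }
}
```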