Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4944e46f
Commit
4944e46f
authored
Jul 19, 2023
by
turneram
Browse files
Use variables for conditions
parent
d8f97e5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
17 deletions
+19
-17
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+19
-17
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
4944e46f
...
@@ -271,9 +271,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -271,9 +271,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
// check consistency of desc
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
constexpr
bool
cond1
=
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
));
if
(
!
cond1
)
{
{
static_assert
(
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
))
,
"e_grid_desc invalid
\n
"
);
static_assert
(
cond1
,
"e_grid_desc invalid
\n
"
);
return
false
;
return
false
;
}
}
...
@@ -284,46 +285,47 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -284,46 +285,47 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
});
if
(
!
valid
)
constexpr
bool
cond2
=
valid
;
if
(
!
cond2
)
{
{
static_assert
(
valid
,
"ds_grid_desc invalid
\n
"
);
static_assert
(
cond2
,
"ds_grid_desc invalid
\n
"
);
return
false
;
return
false
;
}
}
// check tile size
// check tile size
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
constexpr
bool
cond3
=
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
);
if
(
!
cond3
)
{
{
static_assert
(
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
)
,
"tile size invalid
\n
"
);
static_assert
(
cond3
,
"tile size invalid
\n
"
);
return
false
;
return
false
;
}
}
// check gridwise gemm pipeline
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
const
auto
num_k_loop
=
K
/
KPerBlock
;
constexpr
bool
cond4
=
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
);
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
)
)
if
(
!
cond4
)
{
{
static_assert
(
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
)
,
"num_k_loop invalid
\n
"
);
static_assert
(
cond4
,
"num_k_loop invalid
\n
"
);
return
false
;
return
false
;
}
}
// check block-to-E-tile
// check block-to-E-tile
if
(
!
block_2_etile_map
.
CheckValidity
(
e_grid_desc_m_n
))
constexpr
bool
cond5
=
block_2_etile_map
.
CheckValidity
(
e_grid_desc_m_n
);
if
(
!
cond5
)
{
{
static_assert
(
block_2_etile_map
.
CheckValidity
(
e_grid_desc_m_n
)
,
"block_2_etile_map invalid
\n
"
);
static_assert
(
cond5
,
"block_2_etile_map invalid
\n
"
);
return
false
;
return
false
;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
// check tensor size: cannot be larger than 2GB each
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
constexpr
bool
cond6
=
(
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
if
(
!
(
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
))
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
);
if
(
!
cond6
)
{
{
static_assert
((
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
static_assert
(
cond6
,
"invalid tensor (> 2GB)
\n
"
);
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
),
"invalid tensor (> 2GB)
\n
"
);
return
false
;
return
false
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment