Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cb17765e
Commit
cb17765e
authored
Dec 08, 2022
by
rocking
Browse files
Check argument for gemm
parent
eaeef340
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
5 deletions
+89
-5
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+89
-5
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
cb17765e
...
@@ -540,7 +540,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -540,7 +540,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
},
h_element_op_
{
h_element_op
},
h_element_op_
{
h_element_op
},
MRaw_
(
MRaw
),
MRaw_
{
MRaw
},
NRaw_
{
NRaw
},
KRaw_
{
KRaw
},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
...
@@ -638,8 +640,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -638,8 +640,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
HElementwiseOperation
h_element_op_
;
HElementwiseOperation
h_element_op_
;
int
MRaw_
;
index_t
MRaw_
;
int
gemm_nblock_
;
index_t
NRaw_
;
index_t
KRaw_
;
index_t
gemm_nblock_
;
AccDataType
epsilon_
;
AccDataType
epsilon_
;
};
};
...
@@ -829,14 +833,94 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -829,14 +833,94 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_var_
)
+
variance_space_sz
;
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_var_
)
+
variance_space_sz
;
};
};
static
bool
IsSupportedArgument
(
const
Argument
&
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
{
return
false
;
return
false
;
}
}
// TODO
// check vector load/store
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of Ds
// only support RowMajor for now
bool
all_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
if
(
!
all_valid
)
{
return
false
;
}
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
PostShuffleScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
// TODO - layernorm
return
true
;
return
true
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment