Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
da8c0608
Commit
da8c0608
authored
May 30, 2022
by
rocking
Browse files
[What] Suport non pointer for invoker and argument
[Why] Snyc coding style with gemm
parent
2fc2a189
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
21 deletions
+55
-21
example/21_gemm_normalize_xdl/gemm_layernorm_xdl_fp16.cpp
example/21_gemm_normalize_xdl/gemm_layernorm_xdl_fp16.cpp
+20
-20
include/ck/tensor_operation/gpu/device/device_5ary_elementwise_xdl_cshuffle.hpp
...ation/gpu/device/device_5ary_elementwise_xdl_cshuffle.hpp
+35
-1
No files found.
example/21_gemm_normalize_xdl/gemm_layernorm_xdl_fp16.cpp
View file @
da8c0608
...
...
@@ -275,14 +275,14 @@ int main()
// Prepare LayerNorm
auto
normalize
=
DeviceNormalizeInstance
{};
auto
normalize_invoker
_ptr
=
normalize
.
MakeInvoker
Pointer
();
auto
normalize_argument
=
normalize
.
MakeArgumentPointer
(
c_device_buf
.
GetDeviceBuffer
(),
reduceMean_device_buf
.
GetDeviceBuffer
(),
reduceMeanSquare_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
layerNorm_device_buf
.
GetDeviceBuffer
(),
auto
normalize_invoker
=
normalize
.
MakeInvoker
();
auto
normalize_argument
=
normalize
.
MakeArgument
(
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()
)
,
static_cast
<
DDataType
*>
(
reduceMean_device_buf
.
GetDeviceBuffer
()
)
,
static_cast
<
DDataType
*>
(
reduceMeanSquare_device_buf
.
GetDeviceBuffer
()
)
,
static_cast
<
GammaDataType
*>
(
gamma_device_buf
.
GetDeviceBuffer
()
)
,
static_cast
<
BetaDataType
*>
(
beta_device_buf
.
GetDeviceBuffer
()
)
,
static_cast
<
LayerNormOutDataType
*>
(
layerNorm_device_buf
.
GetDeviceBuffer
()
)
,
{
M
,
N
},
{
StrideC
,
1
},
{
1
,
0
},
...
...
@@ -292,7 +292,7 @@ int main()
{
StrideC
,
1
},
NormalizeFunctor
{});
if
(
!
normalize
.
IsSupportedArgument
(
normalize_argument
.
get
()
))
if
(
!
normalize
.
IsSupportedArgument
(
normalize_argument
))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"Device5AryElementwise_Xdl_CShuffle instance, exiting!"
);
...
...
@@ -300,7 +300,7 @@ int main()
// run kernel
gemmReduce_invoker
.
Run
(
gemmReduce_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
normalize_invoker
_ptr
->
Run
(
normalize_argument
.
get
()
,
StreamConfig
{
nullptr
,
time_kernel
});
normalize_invoker
.
Run
(
normalize_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
{
...
...
include/ck/tensor_operation/gpu/device/device_5ary_elementwise_xdl_cshuffle.hpp
View file @
da8c0608
...
...
@@ -215,6 +215,8 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
}
};
bool
IsSupportedArgument
(
const
BaseArgument
&
p_arg
)
{
return
IsSupportedArgument
(
&
p_arg
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
...
...
@@ -260,6 +262,37 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
return
true
;
};
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
CDataType
*
p_c
,
const
DDataType
*
p_d
,
const
EDataType
*
p_e
,
FDataType
*
p_f
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
index_t
>
c_strides
,
std
::
vector
<
index_t
>
d_strides
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
f_strides
,
ElementwiseFunctor
functor
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
p_d
,
p_e
,
p_f
,
lengths
,
a_strides
,
b_strides
,
c_strides
,
d_strides
,
e_strides
,
f_strides
,
functor
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_c
,
...
...
@@ -291,8 +324,9 @@ struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
functor
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
}
};
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment