Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3b065199
Commit
3b065199
authored
Oct 12, 2024
by
letaoqin
Browse files
adding readme and add op
parent
244e313f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
6 deletions
+44
-6
example/66_gemm_bias_activation/README.md
example/66_gemm_bias_activation/README.md
+27
-0
example/66_gemm_bias_activation/gemm_bias_add.hpp
example/66_gemm_bias_activation/gemm_bias_add.hpp
+2
-0
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
+5
-0
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
+10
-6
No files found.
example/66_gemm_bias_activation/README.md
0 → 100644
View file @
3b065199
### Build
```
mkdir -p build
cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make -j example_gemm_bias_add_xdl_fp16
```
### Run Examples
#### args:
```
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
arg4 to 7: M (256x), N(128x), K(32x), op_type(Add = 0, Gelu = 1, Relu = 2, Silu = 3, Sigmoid = 4)
```
#### command:
```
./build/bin/example_gemm_bias_add_xdl_fp16
./build/bin/example_gemm_bias_add_xdl_fp16 1 1 1 64 3072 768 0
./build/bin/example_gemm_bias_add_xdl_fp16 1 1 1 64 3072 768 1
./build/bin/example_gemm_bias_add_xdl_fp16 1 1 1 64 3072 768 2
./build/bin/example_gemm_bias_add_xdl_fp16 1 1 1 64 3072 768 3
./build/bin/example_gemm_bias_add_xdl_fp16 1 1 1 64 3072 768 4
```
example/66_gemm_bias_activation/gemm_bias_add.hpp
View file @
3b065199
...
@@ -66,6 +66,7 @@ struct AddActivation
...
@@ -66,6 +66,7 @@ struct AddActivation
}
// namespace ck
}
// namespace ck
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
Silu
=
ck
::
tensor_operation
::
element_wise
::
Silu
;
using
Silu
=
ck
::
tensor_operation
::
element_wise
::
Silu
;
...
@@ -82,6 +83,7 @@ struct GemmBiasAddArgs
...
@@ -82,6 +83,7 @@ struct GemmBiasAddArgs
ck
::
index_t
K
;
ck
::
index_t
K
;
};
};
float
gemm_bias_add_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_relu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_relu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_gelu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_gelu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_silu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
float
gemm_bias_add_silu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
);
...
...
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
View file @
3b065199
...
@@ -140,6 +140,11 @@ float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config)
...
@@ -140,6 +140,11 @@ float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config)
return
ave_time
;
return
ave_time
;
}
}
float
gemm_bias_add_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
)
{
return
run_impl
<
Add
>
(
args
,
config
);
}
float
gemm_bias_add_relu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
)
float
gemm_bias_add_relu_fp16
(
const
GemmBiasAddArgs
&
args
,
const
StreamConfig
&
config
)
{
{
return
run_impl
<
ck
::
impl
::
AddActivation
<
Relu
>>
(
args
,
config
);
return
run_impl
<
ck
::
impl
::
AddActivation
<
Relu
>>
(
args
,
config
);
...
...
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
View file @
3b065199
...
@@ -172,8 +172,8 @@ int main(int argc, char* argv[])
...
@@ -172,8 +172,8 @@ int main(int argc, char* argv[])
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to
9
: M (256x), N(128x), K(32x)m, op_type(Gelu =
0
, Relu =
1
, Silu =
2,
"
printf
(
"arg4 to
7
: M (256x), N(128x), K(32x)m, op_type(
Add = 0,
Gelu =
1
, Relu =
2
, Silu = "
"Sigmoid =
3
\n
"
);
"
3,
Sigmoid =
4
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -240,10 +240,12 @@ int main(int argc, char* argv[])
...
@@ -240,10 +240,12 @@ int main(int argc, char* argv[])
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
op_type
==
0
)
if
(
op_type
==
0
)
ave_time
=
gemm_bias_add_
gelu_
fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
ave_time
=
gemm_bias_add_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
else
if
(
op_type
==
1
)
else
if
(
op_type
==
1
)
ave_time
=
gemm_bias_add_
r
elu_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
ave_time
=
gemm_bias_add_
g
elu_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
else
if
(
op_type
==
2
)
else
if
(
op_type
==
2
)
ave_time
=
gemm_bias_add_relu_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
else
if
(
op_type
==
3
)
ave_time
=
gemm_bias_add_silu_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
ave_time
=
gemm_bias_add_silu_fp16
(
gemm_args
,
StreamConfig
{
nullptr
,
time_kernel
,
20
,
50
});
else
else
ave_time
=
ave_time
=
...
@@ -283,10 +285,12 @@ int main(int argc, char* argv[])
...
@@ -283,10 +285,12 @@ int main(int argc, char* argv[])
}
}
};
};
if
(
op_type
==
0
)
if
(
op_type
==
0
)
run_elementwise
(
ck
::
impl
::
AddActivation
<
Gelu
>
{});
run_elementwise
(
Add
{});
else
if
(
op_type
==
1
)
else
if
(
op_type
==
1
)
run_elementwise
(
ck
::
impl
::
AddActivation
<
R
elu
>
{});
run_elementwise
(
ck
::
impl
::
AddActivation
<
G
elu
>
{});
else
if
(
op_type
==
2
)
else
if
(
op_type
==
2
)
run_elementwise
(
ck
::
impl
::
AddActivation
<
Relu
>
{});
else
if
(
op_type
==
3
)
run_elementwise
(
ck
::
impl
::
AddActivation
<
Silu
>
{});
run_elementwise
(
ck
::
impl
::
AddActivation
<
Silu
>
{});
else
else
run_elementwise
(
ck
::
impl
::
AddActivation
<
Sigmoid
>
{});
run_elementwise
(
ck
::
impl
::
AddActivation
<
Sigmoid
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment