Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7fd5e9f5
Commit
7fd5e9f5
authored
Jun 13, 2022
by
Chao Liu
Browse files
refactor
parent
e09f6e02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
13 deletions
+23
-13
example/03_gemm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp
...emm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp
+15
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+8
-5
No files found.
example/03_gemm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp
View file @
7fd5e9f5
...
@@ -26,19 +26,26 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -26,19 +26,26 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// E = FastGelu((A * B) + D0 + D1)
// C = A * B
struct
AddAddFastGelu
struct
AddAddFastGelu
{
{
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
void
__host__
__device__
void
operator
()(
ck
::
half_t
&
y
,
const
float
&
x0
,
const
ck
::
half_t
&
x1
,
const
ck
::
half_t
&
x2
)
const
operator
()(
ck
::
half_t
&
e
,
const
float
&
c
,
const
ck
::
half_t
&
d0
,
const
ck
::
half_t
&
d1
)
const
{
{
const
float
x
=
x0
+
x1
+
x2
;
// Fast GeLU
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
// https://paperswithcode.com/method/gelu
const
float
emu
=
exp
(
-
u
);
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
/
(
float
(
1
)
+
emu
)
-
float
(
1
));
const
auto
fast_gelu
=
[
&
](
float
x
)
{
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
/
(
float
(
1
)
+
emu
)
-
float
(
1
));
return
x
*
cdf
;
};
const
float
y
=
fast_gelu
(
c
+
float
(
d0
)
+
float
(
d1
));
y
=
ck
::
type_convert
<
ck
::
half_t
>
(
x
*
cdf
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
y
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
7fd5e9f5
...
@@ -567,8 +567,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -567,8 +567,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
3
,
// index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
Sequence
<
true
,
false
,
false
>
,
//
bool
ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
true
,
false
,
false
>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
//
bool
ThreadTransferDstResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
{
c_ds_descs
,
{
c_ds_descs
,
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
),
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
...
@@ -626,17 +626,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -626,17 +626,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
if
constexpr
(
access_id
<
num_access
-
1
)
if
constexpr
(
access_id
<
num_access
-
1
)
{
{
constexpr
auto
e_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
constexpr
auto
cde_lds_and_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
// move on Ds
static_for
<
0
,
DsDataType
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
DsDataType
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_descs
,
i
+
I1
,
e
_global_step
);
c_ds_descs
,
i
+
I1
,
cde_lds_and
_global_step
);
});
});
// move on E
// move on E
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
e_global_step
);
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
cde_lds_and_global_step
);
}
}
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment