Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ad77ce8e
Commit
ad77ce8e
authored
Nov 21, 2022
by
letaoqin
Browse files
fix for passthrough element op
parent
f820c621
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
97 deletions
+98
-97
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
...tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
+98
-97
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
View file @
ad77ce8e
...
@@ -544,115 +544,116 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
...
@@ -544,115 +544,116 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
blockwise_gemm
.
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
blockwise_gemm
.
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
());
get_thread_local_1d_id
());
const
auto
ds_grid_buf
=
generate_tuple
(
if
constexpr
(
!
is_same_v
<
CDEElementwiseOperation
,
[
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
)
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
{
p_ds_grid
[
i
],
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
].
GetElementSpaceSize
());
const
auto
ds_grid_buf
=
generate_tuple
(
},
[
&
](
auto
i
)
{
Number
<
NumDTensor
>
{});
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
auto
ds_thread_buf
=
generate_tuple
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
].
GetElementSpaceSize
());
[
&
](
auto
i
)
{
},
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
Number
<
NumDTensor
>
{});
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
auto
ds_thread_buf
=
generate_tuple
(
DDataType
,
[
&
](
auto
i
)
{
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
true
>
{};
},
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
Number
<
NumDTensor
>
{});
DDataType
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
auto
ds_threadwise_copy
=
generate_tuple
(
true
>
{};
[
&
](
auto
i
)
{
},
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
Number
<
NumDTensor
>
{});
return
ThreadwiseTensorSliceTransfer_v2
<
auto
ds_threadwise_copy
=
generate_tuple
(
DDataType
,
[
&
](
auto
i
)
{
DDataType
,
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
decltype
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
]),
decltype
(
c_thread_desc_m0_m10_m11_n0_n10_n11
),
return
ThreadwiseTensorSliceTransfer_v2
<
Sequence
<
I1
,
DDataType
,
I1
,
DDataType
,
I1
,
decltype
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
]),
I1
,
decltype
(
c_thread_desc_m0_m10_m11_n0_n10_n11
),
I1
,
Sequence
<
I1
,
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
]
>
{}
>
,
I1
,
CThreadTransferSrcDstAccessOrder
,
I1
,
CThreadTransferSrcDstVectorDim
,
I1
,
CThreadTransferDstScalarPerVector
,
I1
,
1
,
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
]
>
{}
>
,
false
>
(
CThreadTransferSrcDstAccessOrder
,
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
CThreadTransferSrcDstVectorDim
,
make_multi_index
(
CThreadTransferDstScalarPerVector
,
im0
,
1
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I0
],
false
>
(
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
in0
,
make_multi_index
(
im0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I0
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
]));
// register number
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
},
in0
,
Number
<
NumDTensor
>
{});
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
]));
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
1
>
{}([
&
](
auto
m10
)
{
},
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
1
>
{}([
&
](
auto
m11
)
{
Number
<
NumDTensor
>
{});
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
1
>
{}([
&
](
auto
n10
)
{
ignore
=
m10
;
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
1
>
{}([
&
](
auto
m10
)
{
ignore
=
m11
;
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
1
>
{}([
&
](
auto
m11
)
{
ignore
=
n10
;
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
1
>
{}(
ignore
=
ds_thread_buf
;
[
&
](
auto
n10
)
{
ignore
=
ds_threadwise_copy
;
// load d matrix data
ignore
=
ds_grid_buf
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
Run
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_grid_buf
[
i
],
ds_threadwise_copy
(
i
).
Run
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
c_thread_desc_m0_m10_m11_n0_n10_n11
,
ds_grid_buf
[
i
],
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_desc_m0_m10_m11_n0_n10_n11
,
ds_thread_buf
(
i
));
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
});
ds_thread_buf
(
i
));
// cal element op
});
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
1
>
{}(
[
&
](
auto
i
)
{
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
1
>
{}(
// get reference to src data
[
&
](
auto
i
)
{
const
auto
src_data_refs
=
generate_tie
(
// get reference to src data
// return type should be lvalue
const
auto
src_data_refs
=
generate_tie
(
[
&
](
auto
iSrc
)
->
const
auto
&
{
// return type should be lvalue
return
ds_thread_buf
[
iSrc
][
i
];
[
&
](
auto
iSrc
)
->
const
auto
&
{
},
return
ds_thread_buf
[
iSrc
][
i
];
Number
<
NumDTensor
>
{});
},
Number
<
NumDTensor
>
{});
// get reference to dst data
constexpr
index_t
c_offset
=
// get reference to dst data
c_thread_desc_m0_m10_m11_n0_n10_n11
.
CalculateOffset
(
constexpr
index_t
c_offset
=
make_tuple
(
0
,
m10
,
m11
,
0
,
n10
,
i
));
c_thread_desc_m0_m10_m11_n0_n10_n11
.
CalculateOffset
(
auto
dst_data_refs
=
generate_tie
(
make_tuple
(
0
,
m10
,
m11
,
0
,
n10
,
i
));
// return type should be lvalue
auto
dst_data_refs
=
generate_tie
(
[
&
](
auto
)
->
auto
&
{
// return type should be lvalue
return
c_thread_buf
(
Number
<
c_offset
>
{});
[
&
](
auto
)
->
auto
&
{
return
c_thread_buf
(
Number
<
c_offset
>
{});
},
},
Number
<
2
>
{});
Number
<
2
>
{});
unpack2
(
cde_element_op
,
dst_data_refs
,
src_data_refs
);
unpack2
(
cde_element_op
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
1
,
0
));
});
});
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
1
,
0
));
make_multi_index
(
0
,
0
,
1
,
0
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
0
));
});
});
});
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
make_multi_index
(
0
,
0
,
1
,
0
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I
2
]
,
0
));
0
,
1
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I
1
],
0
,
0
,
0
));
});
});
});
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
}
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
1
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
0
,
0
,
0
));
});
});
ThreadwiseTensorSliceTransfer_v1r3
<
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatAcc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment