Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8ea2e1c9
".github/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "684e6aededb43a7976416d1ba64cb8e8034c1cf3"
Commit
8ea2e1c9
authored
Dec 03, 2023
by
Bartlomiej Wroblewski
Browse files
Fix the IsSupported check in contraction op
parent
8ff845f2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
132 additions
and
69 deletions
+132
-69
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+59
-69
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
...or_operation/gpu/device/impl/device_contraction_utils.hpp
+73
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
8ea2e1c9
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
...
@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
}
a_mz_stride_
{},
a_kz_stride_
{},
b_nz_stride_
{},
b_kz_stride_
{},
ds_nz_stride_
{},
e_nz_stride_
{}
{
{
// populate pointer, batch stride, desc for Ds
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -448,18 +443,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -448,18 +443,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
// for sanity check of vector memory access
// for sanity check of vector memory access
a_mz_stride_
=
a_ms_ks_strides
[
NumDimM
-
1
];
a_mz_consecutive_
=
a_ms_ks_strides
[
NumDimM
-
1
]
==
1
;
a_kz_stride_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
];
a_kz_consecutive_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
]
==
1
;
b_nz_consecutive_
=
b_ns_ks_strides
[
NumDimN
-
1
]
==
1
;
b_nz_stride_
=
b_ns_ks_strides
[
NumDimN
-
1
];
b_kz_consecutive_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
]
==
1
;
b_kz_stride_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
];
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
{
ds_nz_
strid
e_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
];
ds_nz_
consecutiv
e_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
}
}
e_nz_consecutive_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
]
==
1
;
e_nz_stride_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
];
a_max_read_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ns_lengths
,
a_ms_ks_strides
);
b_max_read_elems_
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
}
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
}
}
void
Print
()
const
void
Print
()
const
...
@@ -499,15 +503,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -499,15 +503,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store
// in the memory or not.
index_t
a_mz_stride_
;
bool
a_mz_consecutive_
;
index_t
a_kz_stride_
;
bool
a_kz_consecutive_
;
index_t
b_nz_stride_
;
bool
b_nz_consecutive_
;
index_t
b_kz_stride_
;
bool
b_kz_consecutive_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
index_t
e_mz_stride_
;
bool
e_nz_consecutive_
;
index_t
e_nz_stride_
;
index_t
a_max_read_elems_
;
index_t
b_max_read_elems_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_max_read_elems_
;
index_t
e_max_write_elems_
;
};
};
// Invoker
// Invoker
...
@@ -616,65 +624,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -616,65 +624,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
"wrong!"
);
"wrong!"
);
// vector memory access of A: could be on M or AK1 dimension
const
bool
valid_a_vector_size
=
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_mz_consecutive_
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_kz_consecutive_
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
{
{
if
(
!
(
arg
.
a_mz_stride_
==
1
&&
return
false
;
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
}
// vector memory access of B: could be on N or BK1 dimension
const
bool
valid_b_vector_size
=
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
{
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_nz_consecutive_
;
if
(
!
(
arg
.
b_nz_stride_
==
1
&&
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_kz_consecutive_
;
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
{
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
return
false
;
}
}
else
{
{
if
(
!
(
arg
.
b_kz_stride_
==
1
&&
return
false
;
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool
valid_ds_access
=
true
;
bool
valid_d_access
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
const
bool
valid_d_vector_size
=
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
i
].
GetLength
(
I3
)
%
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
CDEBlockTransferScalarPerVector_NPerBlock
==
// Vector read of Ds is always on N dimension.
0
))
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
];
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
{
valid_d_access
=
false
;
valid_d
s
_access
=
false
;
}
}
});
});
if
(
valid_ds_access
==
false
)
if
(
valid_d_access
==
false
)
{
{
return
false
;
return
false
;
}
}
// vector memory access of E: always on NPerBlock dimension
const
bool
valid_e_vector_size
=
if
(
!
(
arg
.
e_
nz_stride_
==
1
&&
arg
.
e_
max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetLength
(
I3
)
%
// Vector write of E is always on N dimension.
CDEBlockTransferScalarPerVector_NPerBlock
==
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
;
0
))
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
0 → 100644
View file @
8ea2e1c9
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
/**
 * Calculates the maximum number of subsequent elements of the fast changing dimension
 * that are consecutive in memory.
 *
 * Example:
 *   NumDim1 = 2, NumDim2 = 3
 *   lengths = [  2,   3,  4, 5, 6]
 *   strides = [360, 120, 30, 6, 1]
 *             | group1 | | group2 |
 * It follows from the strides that the second group is the FCD and all of its subsequent
 * elements are consecutive in memory, so 4 * 5 * 6 = 120 is returned.
 * But if the strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of the
 * second group would be consecutive in memory.
 *
 * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
 *
 * @param lengths Dimension lengths; size must equal NumDim1 + NumDim2.
 * @param strides Dimension strides, in the same order as `lengths`.
 * @return Maximum number of subsequent FCD elements consecutive in memory
 *         (1 if no group ends with a unit-stride dimension).
 */
template <index_t NumDim1, index_t NumDim2>
index_t CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
    // Note: explicit `index_t` return type — the previous `auto` deduced `int` from
    // `return 1;` and `index_t` from the other return, which only compiles while the
    // two types happen to coincide.
    assert(lengths.size() == static_cast<std::size_t>(NumDim1 + NumDim2) &&
           strides.size() == static_cast<std::size_t>(NumDim1 + NumDim2));

    // Determine the beginning and end idx of the group representing the FCD, i.e. the
    // group whose innermost (last) dimension has unit stride.
    index_t begin_idx, end_idx;
    if(strides[NumDim1 - 1] == 1)
    {
        begin_idx = 0;
        end_idx   = NumDim1 - 1;
    }
    else if(strides[NumDim1 + NumDim2 - 1] == 1)
    {
        begin_idx = NumDim1;
        end_idx   = NumDim1 + NumDim2 - 1;
    }
    else
    {
        // The dimension consecutive in memory is not the last dimension of any group, so only
        // one element can be read/written at once.
        return 1;
    }

    // Walk the FCD group from its innermost dimension outwards, accumulating lengths for as
    // long as each dimension's stride equals the size of the contiguous block collected so far.
    index_t consecutive_stride = 1;
    for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
    {
        if(strides[dim_idx] == consecutive_stride)
        {
            consecutive_stride *= lengths[dim_idx];
        }
        else
        {
            break;
        }
    }

    // The accumulated block size is the number of subsequent elements accessible in one go.
    const index_t max_subsequent_elems = consecutive_stride;
    return max_subsequent_elems;
}
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment