Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
54c90aae
Unverified
Commit
54c90aae
authored
Apr 27, 2023
by
Haocong WANG
Committed by
GitHub
Apr 26, 2023
Browse files
add vector load check (#680)
Co-authored-by:
zjing14
<
zhangjing14@gmail.com
>
parent
7613c1d9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
158 additions
and
2 deletions
+158
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+88
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+70
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
54c90aae
...
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
N01_
{
N01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
cde_element_op_
{
cde_element_op
},
MRaw_
{
M
},
NRaw_
{
N
},
KRaw_
{
K
}
{
{
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
...
@@ -335,6 +338,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -335,6 +338,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking vector load/store
index_t
MRaw_
;
index_t
NRaw_
;
index_t
KRaw_
;
};
};
// Invoker
// Invoker
...
@@ -488,6 +496,85 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -488,6 +496,85 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
{
{
return
false
;
return
false
;
}
}
// check vector load/store
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of Ds
// only support RowMajor for now
bool
all_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
if
(
!
all_valid
)
{
return
false
;
}
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
54c90aae
...
@@ -239,7 +239,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -239,7 +239,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
N01_
{
N01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c_element_op_
{
c_element_op
},
MRaw_
{
M
},
NRaw_
{
N
},
KRaw_
{
K
}
{
{
a_grid_desc_k0_m_k1_
=
a_grid_desc_k0_m_k1_
=
DeviceGemmWmma_CShuffle
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
DeviceGemmWmma_CShuffle
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
...
@@ -276,6 +279,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -276,6 +279,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
// for checking vector load/store
index_t
MRaw_
;
index_t
NRaw_
;
index_t
KRaw_
;
};
};
// Invoker
// Invoker
...
@@ -417,6 +424,68 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -417,6 +424,68 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return
false
;
return
false
;
}
}
// check vector load/store
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector store of C
// only support RowMajor for now
if
constexpr
(
is_same_v
<
CLayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment