Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dccaf0b2
Commit
dccaf0b2
authored
Aug 10, 2023
by
letaoqin
Browse files
change check name to d0s_n_length_stride_
parent
514cee8a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
16 deletions
+14
-16
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+8
-10
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
dccaf0b2
...
@@ -721,9 +721,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -721,9 +721,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// D0 pointer
// D0 pointer
p_d0s_grid_
(
i
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
[
i
]);
p_d0s_grid_
(
i
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
[
i
]);
// for check
// for check
d0s_n
l_ns
_length
s
_stride
s
_
[
i
].
push_back
(
d0s_n_length_stride_
[
i
].
push_back
(
acc0_biases_gs_ms_ns_lengths
[
i
][
NumDimG
+
NumDimM
]);
acc0_biases_gs_ms_ns_lengths
[
i
][
NumDimG
+
NumDimM
]);
d0s_n
l_ns
_length
s
_stride
s
_
[
i
].
push_back
(
d0s_n_length_stride_
[
i
].
push_back
(
acc0_biases_gs_ms_ns_strides
[
i
][
NumDimG
+
NumDimM
]);
acc0_biases_gs_ms_ns_strides
[
i
][
NumDimG
+
NumDimM
]);
});
});
is_dropout_
=
p_dropout
>
0.0
;
//
is_dropout_
=
p_dropout
>
0.0
;
//
...
@@ -830,7 +830,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -830,7 +830,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t
n_raw_padded_
;
index_t
n_raw_padded_
;
// raw data
// raw data
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n
l_ns
_length
s
_stride
s
_
;
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n_length_stride_
;
};
};
// Invoker
// Invoker
...
@@ -1039,12 +1039,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1039,12 +1039,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
for
(
int
i
=
0
;
i
<
NumD0Tensor
;
i
++
)
for
(
int
i
=
0
;
i
<
NumD0Tensor
;
i
++
)
{
{
if
(
arg
.
d0s_n
l_ns
_length
s
_stride
s
_
[
i
][
1
]
==
1
&&
if
(
arg
.
d0s_n_length_stride_
[
i
][
1
]
==
1
&&
arg
.
d0s_n
l_ns
_length
s
_stride
s
_
[
i
][
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
arg
.
d0s_n_length_stride_
[
i
][
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
{
{
return
false
;
return
false
;
}
}
if
(
arg
.
d0s_n
l_ns
_length
s
_stride
s
_
[
i
][
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
if
(
arg
.
d0s_n_length_stride_
[
i
][
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
dccaf0b2
...
@@ -658,7 +658,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -658,7 +658,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
// raw data
// raw data
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n
l_ns
_length
s
_stride
s
_
;
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n_length_stride_
;
};
};
// Argument
// Argument
...
@@ -708,16 +708,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -708,16 +708,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const
auto
p_b_grid
=
static_cast
<
const
BDataType
*>
(
p_b_vec
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
BDataType
*>
(
p_b_vec
[
i
]);
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n
l_ns
_length
s
_stride
s
;
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_n_length_stride
;
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid
;
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid
;
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
j
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
Acc0BiasDataType
>>
;
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
Acc0BiasDataType
>>
;
// D0 pointer
// D0 pointer
p_d0s_grid
(
j
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases_vec
[
i
][
j
]);
p_d0s_grid
(
j
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases_vec
[
i
][
j
]);
// for check
// for check
d0s_n
l_ns
_length
s
_stride
s
[
j
].
push_back
(
d0s_n_length_stride
[
j
].
push_back
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
j
][
NumDimG
+
NumDimM
]);
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
j
][
NumDimG
+
NumDimM
]);
d0s_n
l_ns
_length
s
_stride
s
[
j
].
push_back
(
d0s_n_length_stride
[
j
].
push_back
(
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
j
][
NumDimG
+
NumDimM
]);
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
j
][
NumDimG
+
NumDimM
]);
});
});
...
@@ -859,7 +859,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -859,7 +859,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
,
c_grid_desc_m_n
,
d0s_n
l_ns
_length
s
_stride
s
});
d0s_n_length_stride
});
}
}
is_dropout_
=
p_dropout
>
0.0
;
//
is_dropout_
=
p_dropout
>
0.0
;
//
...
@@ -1081,14 +1081,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1081,14 +1081,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
for
(
int
In
=
0
;
In
<
NumD0Tensor
;
In
++
)
for
(
int
In
=
0
;
In
<
NumD0Tensor
;
In
++
)
{
{
if
(
device_arg
.
d0s_nl_ns_lengths_strides_
[
In
][
1
]
==
1
&&
if
(
device_arg
.
d0s_n_length_stride_
[
In
][
1
]
==
1
&&
device_arg
.
d0s_nl_ns_lengths_strides_
[
In
][
0
]
%
device_arg
.
d0s_n_length_stride_
[
In
][
0
]
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
Acc0BiasTransferSrcScalarPerVector
!=
0
)
{
{
return
false
;
return
false
;
}
}
if
(
device_arg
.
d0s_n
l_ns
_length
s
_stride
s
_
[
In
][
1
]
!=
1
&&
if
(
device_arg
.
d0s_n_length_stride_
[
In
][
1
]
!=
1
&&
Acc0BiasTransferSrcScalarPerVector
!=
1
)
Acc0BiasTransferSrcScalarPerVector
!=
1
)
{
{
return
false
;
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment