Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9574b34d
Commit
9574b34d
authored
Oct 09, 2023
by
danyao12
Browse files
adjust grouped kernels interface
parent
29398e70
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
31 additions
and
66 deletions
+31
-66
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+0
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+0
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+2
-6
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+2
-4
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+0
-2
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc
..._bias/run_grouped_multihead_attention_bias_forward_v2.inc
+2
-4
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
+0
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+4
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+4
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+4
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+4
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+4
-7
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
9574b34d
...
@@ -645,7 +645,6 @@ int run(int argc, char* argv[])
...
@@ -645,7 +645,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -694,7 +693,6 @@ int run(int argc, char* argv[])
...
@@ -694,7 +693,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
9574b34d
...
@@ -657,7 +657,6 @@ int run(int argc, char* argv[])
...
@@ -657,7 +657,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -707,7 +706,6 @@ int run(int argc, char* argv[])
...
@@ -707,7 +706,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
9574b34d
...
@@ -721,8 +721,7 @@ int run(int argc, char* argv[])
...
@@ -721,8 +721,7 @@ int run(int argc, char* argv[])
Scale
{
alpha
},
Scale
{
alpha
},
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should
{
seed
,
offset
});
// dropout random seed and offset, offset should
// be at least the number of elements on a thread
// be at least the number of elements on a thread
...
@@ -770,7 +769,6 @@ int run(int argc, char* argv[])
...
@@ -770,7 +769,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_bwd
(
gemm_bwd
.
GetWorkSpaceSize
(
&
argument_bwd
));
DeviceMem
problem_desc_workspace_bwd
(
gemm_bwd
.
GetWorkSpaceSize
(
&
argument_bwd
));
...
@@ -820,8 +818,7 @@ int run(int argc, char* argv[])
...
@@ -820,8 +818,7 @@ int run(int argc, char* argv[])
Scale
{
alpha
},
Scale
{
alpha
},
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should
{
seed
,
offset
});
// dropout random seed and offset, offset should
// be at least the number of elements on a thread
// be at least the number of elements on a thread
...
@@ -861,7 +858,6 @@ int run(int argc, char* argv[])
...
@@ -861,7 +858,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_bwd_verify
(
gemm_bwd
.
GetWorkSpaceSize
(
&
argument_bwd
));
DeviceMem
problem_desc_workspace_bwd_verify
(
gemm_bwd
.
GetWorkSpaceSize
(
&
argument_bwd
));
gemm_bwd
.
SetWorkSpacePointer
(
&
argument_bwd
,
gemm_bwd
.
SetWorkSpacePointer
(
&
argument_bwd
,
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
9574b34d
...
@@ -258,8 +258,7 @@ int run(int argc, char* argv[])
...
@@ -258,8 +258,7 @@ int run(int argc, char* argv[])
acc0_element_op
,
acc0_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should be
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// at least the number of elements on a thread
...
@@ -302,8 +301,7 @@ int run(int argc, char* argv[])
...
@@ -302,8 +301,7 @@ int run(int argc, char* argv[])
acc0_element_op
,
acc0_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should be
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// at least the number of elements on a thread
// specify workspace for problem_desc
// specify workspace for problem_desc
...
...
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
9574b34d
...
@@ -684,7 +684,6 @@ int run(int argc, char* argv[])
...
@@ -684,7 +684,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -733,7 +732,6 @@ int run(int argc, char* argv[])
...
@@ -733,7 +732,6 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc
View file @
9574b34d
...
@@ -280,8 +280,7 @@ int run(int argc, char* argv[])
...
@@ -280,8 +280,7 @@ int run(int argc, char* argv[])
acc0_element_op
,
acc0_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should be
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// at least the number of elements on a thread
...
@@ -336,8 +335,7 @@ int run(int argc, char* argv[])
...
@@ -336,8 +335,7 @@ int run(int argc, char* argv[])
acc0_element_op
,
acc0_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
p_drop
,
// dropout ratio
h_ratio
,
{
seed
,
offset
});
// dropout random seed and offset, offset should be
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// at least the number of elements on a thread
// specify workspace for problem_desc
// specify workspace for problem_desc
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
View file @
9574b34d
...
@@ -134,7 +134,6 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
...
@@ -134,7 +134,6 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
float
p_dropout
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
=
0
;
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
9574b34d
...
@@ -1321,8 +1321,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1321,8 +1321,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
index_t
b1_gemm1n
=
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
9574b34d
...
@@ -1353,8 +1353,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1353,8 +1353,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
index_t
b1_gemm1n
=
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
9574b34d
...
@@ -1180,8 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1180,8 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
index_t
b1_gemm1n
=
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
9574b34d
...
@@ -1213,8 +1213,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1213,8 +1213,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
index_t
b1_gemm1n
=
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
9574b34d
...
@@ -1024,8 +1024,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1024,8 +1024,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
9574b34d
...
@@ -930,15 +930,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -930,15 +930,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
},
p_dropout_
{
p_drop
}
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -972,6 +970,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -972,6 +970,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d_grid_size_
=
0
;
d_grid_size_
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -1453,7 +1454,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1453,7 +1454,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1478,7 +1478,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1478,7 +1478,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1509,7 +1508,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1509,7 +1508,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1534,7 +1532,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1534,7 +1532,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
9574b34d
...
@@ -999,15 +999,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -999,15 +999,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
},
p_dropout_
{
p_drop
}
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -1041,6 +1039,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1041,6 +1039,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d_grid_size_
=
0
;
d_grid_size_
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -1527,7 +1528,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1527,7 +1528,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1552,7 +1552,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1552,7 +1552,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1583,7 +1582,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1583,7 +1582,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1608,7 +1606,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1608,7 +1606,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
9574b34d
...
@@ -812,15 +812,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -812,15 +812,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
},
p_dropout_
{
p_drop
}
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -851,6 +849,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -851,6 +849,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t
z_random_matrix_offset
=
0
;
index_t
z_random_matrix_offset
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -1297,7 +1298,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1297,7 +1298,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1321,7 +1321,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1321,7 +1321,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1351,7 +1350,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1351,7 +1350,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1375,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1375,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
9574b34d
...
@@ -882,15 +882,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -882,15 +882,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
},
p_dropout_
{
p_drop
}
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -921,6 +919,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -921,6 +919,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t
z_random_matrix_offset
=
0
;
index_t
z_random_matrix_offset
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -1372,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1372,7 +1373,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1396,7 +1396,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1396,7 +1396,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1426,7 +1425,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1426,7 +1425,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1450,7 +1448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1450,7 +1448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
9574b34d
...
@@ -682,14 +682,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -682,14 +682,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
float
p_dropout
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
}
h_ratio_
{
h_ratio
}
{
{
ignore
=
p_acc1_biases_vec
;
ignore
=
p_acc1_biases_vec
;
// TODO ANT: implement bias addition
// TODO ANT: implement bias addition
...
@@ -708,6 +706,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -708,6 +706,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t
z_random_matrix_offset
=
0
;
index_t
z_random_matrix_offset
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b0_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
...
@@ -1214,7 +1215,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1214,7 +1215,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
float
p_dropout
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_a_vec
,
return
Argument
{
p_a_vec
,
...
@@ -1232,7 +1232,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1232,7 +1232,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_dropout
,
p_dropout
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1255,7 +1254,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1255,7 +1254,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
float
p_dropout
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
p_a_vec
,
return
std
::
make_unique
<
Argument
>
(
p_a_vec
,
...
@@ -1273,7 +1271,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1273,7 +1271,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_dropout
,
p_dropout
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment