gaoqiong / composable_kernel · commit 198558c5

Commit 198558c5, authored Sep 26, 2023 by danyao12
Parent: 6a2d7c9f

    train mqa/gqa

Showing 3 changed files with 129 additions and 58 deletions (+129 / -58):

  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp   +1   -1
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp      +58  -23
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp      +70  -34
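Editor's note: "mqa/gqa" in the commit message refers to multi-query / grouped-query attention, where the G1 query heads (h_q) share a smaller set of G2 key/value heads (h_kv) and h_ratio = G1 / G2, as the diffs below introduce. The following standalone sketch only illustrates that head mapping; it is not code from this commit, and its constants and output are made up.

// Illustrative sketch (not part of the commit): how a GQA/MQA setup maps
// query heads to key/value heads. G1 = h_q, G2 = h_kv, h_ratio = G1 / G2.
// G2 == G1 is ordinary multi-head attention, G2 == 1 is MQA, 1 < G2 < G1 is GQA.
#include <cassert>
#include <cstdio>

int main()
{
    const int G1 = 6; // query heads (h_q); value chosen for illustration
    const int G2 = 3; // key/value heads (h_kv); value chosen for illustration
    assert(G1 % G2 == 0);
    const int h_ratio = G1 / G2;

    for(int g1 = 0; g1 < G1; ++g1)
    {
        const int g2 = g1 / h_ratio; // KV head that serves query head g1
        std::printf("query head %d -> kv head %d\n", g1, g2);
    }
    return 0;
}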
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -277,7 +277,7 @@ int run(int argc, char* argv[])
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
     ck::index_t G1 = 6; // h_q
-    ck::index_t G2 = 1; // h_kv
+    ck::index_t G2 = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
@@ -305,8 +305,9 @@ int run(int argc, char* argv[])
     ck::index_t M  = 500; // 512
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4; // 54
-    ck::index_t G1 = 6; // 16
+    ck::index_t G0 = 4;
+    ck::index_t G1 = 6; // h_q
+    ck::index_t G2 = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;
@@ -325,7 +326,7 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -337,20 +338,21 @@ int run(int argc, char* argv[])
         O  = std::stoi(argv[7]);
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
+        G2 = std::stoi(argv[10]);
-        p_drop = std::stof(argv[10]);
+        p_drop = std::stof(argv[11]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        input_permute  = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg11: p_drop\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }
@@ -368,6 +370,7 @@ int run(int argc, char* argv[])
     std::cout << "O: " << O << std::endl;
     std::cout << "G0: " << G0 << std::endl;
     std::cout << "G1: " << G1 << std::endl;
+    std::cout << "G2: " << G2 << std::endl;
     std::cout << "alpha: " << alpha << std::endl;
     std::cout << "input_permute: " << input_permute << std::endl;
     std::cout << "output_permute: " << output_permute << std::endl;
@@ -383,17 +386,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> k_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> v_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
     std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> y_gs_ms_os_strides =
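Editor's note: the comments in the hunk above name two memory layouts for the same logical K/V tensors; the stride vectors are what distinguish them. The sketch below is only an illustration of the usual offset = sum(idx[d] * stride[d]) convention those vectors follow; it is not code from the commit and all its values are made up.

// Illustrative sketch (not part of the commit): the same logical [G0, G2, N, K]
// tensor stored either permuted as [G0, N, G2, K] or packed as [G0, G2, N, K],
// distinguished only by the per-dimension strides.
#include <array>
#include <cstdio>

int main()
{
    const int G2 = 2, N = 4, K = 8;
    const std::array<int, 4> permuted{N * G2 * K, K, G2 * K, 1}; // input_permute == true
    const std::array<int, 4> packed{G2 * N * K, N * K, K, 1};    // input_permute == false

    const std::array<int, 4> idx{1, 1, 2, 3}; // logical index (g0, g2, n, k)
    auto offset = [&](const std::array<int, 4>& stride) {
        int off = 0;
        for(int d = 0; d < 4; ++d)
            off += idx[d] * stride[d]; // flat offset = dot(index, stride)
        return off;
    };
    std::printf("permuted offset: %d, packed offset: %d\n", offset(permuted), offset(packed));
    return 0;
}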
@@ -406,6 +409,18 @@ int run(int argc, char* argv[])
         input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
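Editor's note: written out, the log-sum-exp identity that the comment above compresses is (a restatement for clarity, not text from the commit):

\[
P_i \;=\; \frac{e^{S_i}}{\sum_j e^{S_j}}
      \;=\; \frac{e^{S_i}}{e^{\log \sum_j e^{S_j}}}
      \;=\; e^{\,S_i - \mathrm{LSE}},
\qquad \mathrm{LSE} = \log \sum_j e^{S_j},
\]

so the backward pass can rebuild each softmax row from S and the stored per-row LSE instead of redoing the row reduction.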
@@ -424,8 +439,10 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
     Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -612,6 +629,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
+        kgrad_gs_ns_ks_lengths,
+        kgrad_gs_ns_ks_strides,
+        vgrad_gs_os_ns_lengths,
+        vgrad_gs_os_ns_strides,
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -656,8 +677,10 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
     Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
     Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
     Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
@@ -760,6 +783,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
+        kgrad_gs_ns_ks_lengths,
+        kgrad_gs_ns_ks_strides,
+        vgrad_gs_os_ns_lengths,
+        vgrad_gs_os_ns_strides,
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -795,11 +822,19 @@ int run(int argc, char* argv[])
     q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
         q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
-    k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-        k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+    k_g_n_k.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
     });
-    v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-        v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+    v_g_n_o.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
     });
     z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
         z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
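Editor's note, as a concrete reading of the mapping above (numbers chosen for illustration, not taken from the commit): the flattened batch index idx[0] = g0 * G1 + g1 is decomposed back into g0 = idx[0] / G1 and g1 = idx[0] % G1, and g2 = g1 / (G1 / G2) picks the KV head. With G1 = 6 and G2 = 3, query heads 0-1 read KV head 0, heads 2-3 read KV head 1, and heads 4-5 read KV head 2; G2 = 1 makes every query head read the single KV head (MQA), while G2 = G1 reduces to the old one-to-one copy (standard MHA).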
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
@@ -302,6 +302,7 @@ int run(int argc, char* argv[])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     float alpha  = 1.f / std::sqrt(DIM);
     float p_drop = 0.2;
+    int h_ratio  = 1; // G1 / G2
     bool input_permute  = true;
     bool output_permute = true;
@@ -319,25 +320,26 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 7)
+    else if(argc == 8)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
-        p_drop          = std::stof(argv[4]);
+        p_drop          = std::stof(argv[4]);
+        h_ratio         = std::stof(argv[5]);
-        input_permute   = std::stoi(argv[5]);
-        output_permute  = std::stoi(argv[6]);
+        input_permute   = std::stoi(argv[6]);
+        output_permute  = std::stoi(argv[7]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4: p_drop\n");
+        printf("arg5: h_ratio\n");
+        printf("arg6 to 7: input / output permute\n");
         exit(0);
     }
@@ -412,24 +414,25 @@ int run(int argc, char* argv[])
         int K  = DIM;
         int O  = DIM;
         int G0 = rand() % 4 + 1;
-        int G1 = rand() % 4 + 1;
+        int G2 = rand() % 4 + 1;
+        int G1 = G2 * h_ratio;
         std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
         std::vector<ck::index_t> q_gs_ms_ks_strides =
             input_permute
                 ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
                 : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-        std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
         std::vector<ck::index_t> k_gs_ns_ks_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+                ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+                : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-        std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
         std::vector<ck::index_t> v_gs_os_ns_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+                ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+                : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
         std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
         std::vector<ck::index_t> y_gs_ms_os_strides =
@@ -442,6 +445,17 @@ int run(int argc, char* argv[])
             input_permute
                 ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
                 : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+        std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+            input_permute
+                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+        std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
         // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
         // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
         //         = exp(Si) / exp(log(sum(exp() + ...)))
@@ -481,6 +495,10 @@ int run(int argc, char* argv[])
             y_gs_ms_os_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
+            kgrad_gs_ns_ks_lengths,
+            kgrad_gs_ns_ks_strides,
+            vgrad_gs_os_ns_lengths,
+            vgrad_gs_os_ns_strides,
             {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
             {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
             {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -510,6 +528,8 @@ int run(int argc, char* argv[])
         Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
+        Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+        Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
         if(i < 4)
         {
             std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
@@ -518,6 +538,8 @@ int run(int argc, char* argv[])
             std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
             std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
             std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
+            std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
+            std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
         }
         z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
         z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
@@ -598,11 +620,19 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        k_g_n_k.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
         });
-        v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        v_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         q_g_m_ks.push_back(q_g_m_k);
@@ -624,6 +654,8 @@ int run(int argc, char* argv[])
         z_bwd_tensors.push_back(z_bwd_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms);
         ygrad_tensors.push_back(ygrad_gs_ms_os);
+        kgrad_tensors.push_back(kgrad_gs_ns_ks);
+        vgrad_tensors.push_back(vgrad_gs_os_ns);
         q_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
@@ -641,10 +673,10 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
         qgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
-        kgrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
-        vgrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
+        kgrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
+        vgrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
         ygrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
@@ -689,7 +721,8 @@ int run(int argc, char* argv[])
             Scale{alpha},
             QKVElementOp{},
             YElementOp{},
-            p_drop,          // dropout ratio
+            p_drop,          // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should
                              // be at least the number of elements on a thread
@@ -737,6 +770,7 @@ int run(int argc, char* argv[])
             QKVElementOp{},
             YElementOp{},
             p_drop,
+            h_ratio,
             std::tuple<unsigned long long, unsigned long long>(seed, offset));
         DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
@@ -786,7 +820,8 @@ int run(int argc, char* argv[])
             Scale{alpha},
             QKVElementOp{},
             YElementOp{},
-            p_drop,          // dropout ratio
+            p_drop,          // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should
                              // be at least the number of elements on a thread
@@ -826,6 +861,7 @@ int run(int argc, char* argv[])
             QKVElementOp{},
             YElementOp{},
             p_drop,
+            h_ratio,
             std::tuple<unsigned long long, unsigned long long>(seed, offset));
         DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
         gemm_bwd.SetWorkSpacePointer(&argument_bwd,
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_fwd_tensors_device
[
i
]
->
FromDevice
(
z_fwd_tensors
[
i
].
mData
.
data
());
z_fwd_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
...
...
@@ -863,7 +899,7 @@ int run(int argc, char* argv[])
                                   p_dropout_in_uint8_t,
                                   rp_dropout);
-        int G0 = v_tensors[i].GetLengths()[0];
+        int G0 = q_tensors[i].GetLengths()[0];
         int O  = v_tensors[i].GetLengths()[2];
         int N  = v_tensors[i].GetLengths()[3];
         int M  = q_tensors[i].GetLengths()[2];
@@ -921,10 +957,10 @@ int run(int argc, char* argv[])
         Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
                                                           q_tensors[i].GetStrides());
-        Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
-                                                          k_tensors[i].GetStrides());
-        Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
-                                                          v_tensors[i].GetStrides());
+        Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
+                                                          kgrad_tensors[i].GetStrides());
+        Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
+                                                          vgrad_tensors[i].GetStrides());
         Tensor<InputDataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(),
                                                      y_tensors[i].GetStrides());
         Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(),
@@ -932,10 +968,10 @@ int run(int argc, char* argv[])
         Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
                                                             q_tensors[i].GetStrides());
-        Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
-                                                            k_tensors[i].GetStrides());
-        Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
-                                                            v_tensors[i].GetStrides());
+        Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
+                                                            kgrad_tensors[i].GetStrides());
+        Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
+                                                            vgrad_tensors[i].GetStrides());
         Tensor<InputDataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(),
                                                        y_tensors[i].GetStrides());
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(),