OpenDAS / Lmdeploy · Commits

Unverified commit 1f88baa5, authored Jul 17, 2023 by q.yao, committed by GitHub on Jul 17, 2023.

update log info (#131)

* update log info
* format cuda utils

Parent: db3b986b

Showing 6 changed files with 121 additions and 98 deletions (+121 -98).
src/turbomind/utils/cublasMMWrapper.cc                  +25 -25
src/turbomind/utils/cuda_utils.cc                        +2  -2
src/turbomind/utils/cuda_utils.h                         +6  -6
src/turbomind/utils/gemm_test/encoder_igemm_func.cc      +7  -7
src/turbomind/utils/logger.cc                            +1  -1
tests/unittests/test_attention_kernels.cu               +80 -57
src/turbomind/utils/cublasMMWrapper.cc (view file @ 1f88baa5)

@@ -647,11 +647,11 @@ void cublasMMWrapper::SpGemm(cublasOperation_t transa,
                              void*             C)
 {
     if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) {
-        throw std::runtime_error("\n[FT][ERROR] sparse GEMM only supports FP16 data type now.");
+        throw std::runtime_error("\n[TM][ERROR] sparse GEMM only supports FP16 data type now.");
     }
     static bool not_printed_fp32_accumulation_warning = true;
     if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) {
-        printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, "
+        printf("[TM][WARNING] cublasMMWrapper sets to FP32 compute type, "
                "but sparse gemm will use FP16 compute type since cusparselt "
                "supports FP16 accumulation only.\n");
         not_printed_fp32_accumulation_warning = false;
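A note on the second change in this hunk: the FP32-accumulation warning is guarded by a function-local static flag so it is printed at most once per process. A minimal standalone sketch of that warn-once idiom (the function name and message below are illustrative, not taken from the file):

    #include <cstdio>

    // Warn-once idiom: the static flag is initialized exactly once, so the
    // warning is emitted on the first call only (illustrative sketch).
    void warn_fp16_accumulation_once()
    {
        static bool not_printed_warning = true;
        if (not_printed_warning) {
            printf("[TM][WARNING] sparse gemm will use FP16 accumulation.\n");
            not_printed_warning = false;
        }
    }

    int main()
    {
        warn_fp16_accumulation_once();  // prints the warning
        warn_fp16_accumulation_once();  // prints nothing
        return 0;
    }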
src/turbomind/utils/cuda_utils.cc (view file @ 1f88baa5)

@@ -38,7 +38,7 @@ void print_to_file(const T* result, const int size, const char* file, cudaStream
         delete[] tmp;
     }
     else {
-        throw std::runtime_error(std::string("[FT][ERROR] Cannot open file: ") + file + "\n");
+        throw std::runtime_error(std::string("[TM][ERROR] Cannot open file: ") + file + "\n");
     }
     cudaDeviceSynchronize();
     check_cuda_error(cudaGetLastError());

@@ -81,7 +81,7 @@ void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string na
         }
         max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i]));
     }
-    printf("[INFO][FT] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
+    printf("[TM][INFO] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s",
           name.c_str(),
           size,
           sum / size,
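For context, the print_abs_mean hunk logs a running absolute sum, mean and maximum over a host copy of the buffer. A rough host-only sketch of that reduction (illustrative; the real function copies the buffer from the device, handles more dtypes and also reports whether an inf was found):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Host-side abs-sum / abs-max reduction mirroring the loop shown above.
    // Assumes a non-empty buffer; the device copy and inf check are omitted.
    void print_abs_stats(const std::vector<float>& h_tmp, const char* name)
    {
        double sum     = 0.0;
        float  max_val = 0.0f;
        for (size_t i = 0; i < h_tmp.size(); ++i) {
            sum += fabs(double(h_tmp[i]));
            max_val = max_val > fabsf(h_tmp[i]) ? max_val : fabsf(h_tmp[i]);
        }
        printf("[TM][INFO] %20s size: %zu, abs mean: %f, abs sum: %f, abs max: %f\n",
               name, h_tmp.size(), sum / h_tmp.size(), sum, max_val);
    }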
src/turbomind/utils/cuda_utils.h (view file @ 1f88baa5)

@@ -119,7 +119,7 @@ template<typename T>
 void check(T result, char const* const func, const char* const file, int const line)
 {
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
                                  + file + ":" + std::to_string(line) + "\n");
     }
 }

@@ -137,7 +137,7 @@ inline void syncAndCheck(const char* const file, int const line)
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result))
                                  + " " + file + ":" + std::to_string(line) + "\n");
     }
     TM_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line));

@@ -148,7 +148,7 @@ inline void syncAndCheck(const char* const file, int const line)
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " "
                                  + file + ":" + std::to_string(line) + "\n");
     }
 #endif

@@ -194,12 +194,12 @@ void check_abs_mean_val(const T* result, const int size);
 #define PRINT_FUNC_NAME_()                                                       \
     do {                                                                         \
-        std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl;          \
+        std::cout << "[TM][CALL] " << __FUNCTION__ << " " << std::endl;          \
     } while (0)

 [[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
 {
-    throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
+    throw std::runtime_error(std::string("[TM][ERROR] ") + info + " Assertion fail: " + file + ":"
                              + std::to_string(line) + "\n");
 }

@@ -226,7 +226,7 @@ inline void myAssert(bool result, const char* const file, int const line, std::s
     {                                                                            \
         cusparseStatus_t status = (func);                                        \
         if (status != CUSPARSE_STATUS_SUCCESS) {                                 \
-            throw std::runtime_error(std::string("[FT][ERROR] CUSPARSE API failed at line ")     \
+            throw std::runtime_error(std::string("[TM][ERROR] CUSPARSE API failed at line ")     \
                 + std::to_string(__LINE__) + " in file " + __FILE__ + ": "       \
                 + cusparseGetErrorString(status) + " " + std::to_string(status)); \
         }                                                                        \
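Every hunk in this header follows the same error-reporting pattern: capture __FILE__ and __LINE__ at the call site and throw a std::runtime_error whose message starts with the [TM] prefix and includes the decoded CUDA error. A minimal self-contained sketch of that pattern (the macro and helper names here are illustrative stand-ins for check_cuda_error / _cudaGetErrorEnum in the real header):

    #include <cuda_runtime.h>
    #include <stdexcept>
    #include <string>

    // Throw with file/line context, mirroring check() in cuda_utils.h (sketch).
    inline void check_impl(cudaError_t result, const char* file, int line)
    {
        if (result != cudaSuccess) {
            throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ")
                                     + cudaGetErrorString(result) + " " + file + ":"
                                     + std::to_string(line) + "\n");
        }
    }

    // The macro records the caller's location, as check_cuda_error does.
    #define CHECK_CUDA_SKETCH(call) check_impl((call), __FILE__, __LINE__)

    // Usage: CHECK_CUDA_SKETCH(cudaMemset(ptr, 0, bytes));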
src/turbomind/utils/gemm_test/encoder_igemm_func.cc (view file @ 1f88baa5)

@@ -1252,7 +1252,7 @@ int generate_encoder_igemm_config(
     cudaDeviceSynchronize();
     cudaError_t result = cudaGetLastError();
     if (result) {
-        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: "));
+        throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: "));
     }

     float exec_time = 99999.0f;
src/turbomind/utils/logger.cc (view file @ 1f88baa5)

@@ -47,7 +47,7 @@ Logger::Logger()
     }
     else {
         fprintf(stderr,
-                "[FT][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
+                "[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
                 "Ignore the environment variable and use a default "
                 "logging level.\n",
                 level_name);
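The Logger constructor reads the TM_LOG_LEVEL environment variable and, when the value is not a recognized level name, prints the warning above and keeps a default level. A rough sketch of that fallback logic (the accepted names and the numeric default here are assumptions for illustration, not taken from logger.cc):

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Environment-driven log level with a warning fallback (sketch).
    // Level names and the default value below are assumed, not from the repo.
    int log_level_from_env()
    {
        const int   default_level = 2;  // assume 2 means INFO in this sketch
        const char* level_name    = std::getenv("TM_LOG_LEVEL");
        if (level_name == nullptr) {
            return default_level;
        }
        if (std::strcmp(level_name, "DEBUG") == 0) return 1;
        if (std::strcmp(level_name, "INFO") == 0) return 2;
        if (std::strcmp(level_name, "WARNING") == 0) return 3;
        if (std::strcmp(level_name, "ERROR") == 0) return 4;
        fprintf(stderr,
                "[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. "
                "Ignore the environment variable and use a default "
                "logging level.\n",
                level_name);
        return default_level;
    }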
tests/unittests/test_attention_kernels.cu (view file @ 1f88baa5)

@@ -14,13 +14,13 @@
  * limitations under the License.
  */
-#include "tests/unittests/gtest_utils.h"
 #include "src/turbomind/kernels/gen_relative_pos_bias.h"
 #include "src/turbomind/kernels/gpt_kernels.h"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
-#include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/memory_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
+#include "src/turbomind/utils/Tensor.h"
+#include "tests/unittests/gtest_utils.h"
 #include <curand.h>
 #include <sstream>
@@ -47,12 +47,13 @@ struct AttentionKernelTestParam {
 namespace utils {

-#define CHECK_CURAND(cmd) do {                                                              \
+#define CHECK_CURAND(cmd)                                                                   \
+    do {                                                                                    \
         curandStatus_t err = cmd;                                                           \
         if (err != CURAND_STATUS_SUCCESS) {                                                 \
-            throw std::runtime_error(                                                       \
-                fmtstr("[FT][ERROR] curand runtime error: %d", err));                       \
+            throw std::runtime_error(fmtstr("[TM][ERROR] curand runtime error: %d", err));  \
         }                                                                                   \
-}                                                                                           \
-while(0)
+    } while (0)

 __global__ void convert_and_copy(half* dst, const float* src, const size_t size)
 {
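The reformatted CHECK_CURAND macro keeps the usual do { ... } while (0) wrapper, which makes the expansion behave as a single statement so a trailing semicolon works even inside an unbraced if/else. A minimal illustration of the same idiom with a plain integer status code (not the real macro; curand is deliberately left out):

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // do { ... } while (0) makes the macro expand to exactly one statement,
    // so "if (cond) CHECK_STATUS(expr); else ..." parses as intended (sketch).
    #define CHECK_STATUS(cmd)                                                  \
        do {                                                                   \
            int err_ = (cmd);                                                  \
            if (err_ != 0) {                                                   \
                throw std::runtime_error("[TM][ERROR] status: "                \
                                         + std::to_string(err_));              \
            }                                                                  \
        } while (0)

    int might_fail(bool fail) { return fail ? 7 : 0; }

    int main()
    {
        bool ok = true;
        if (ok)
            CHECK_STATUS(might_fail(false));  // expands to a single statement
        else
            printf("not reached\n");
        return 0;
    }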
@@ -193,7 +194,7 @@ void computeQkSoftmax(T* attn_score,
 }

 template<typename T>
 class AttentionKernelTest: public FtTestBase {
 private:
     using FtTestBase::stream;

@@ -252,10 +253,11 @@ public:
         FtTestBase::TearDown();
     }

     void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
     {
         DataType dtype = getTensorType<T>();
         std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
         bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;

@@ -332,10 +334,11 @@ public:
         }
     }

     void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
     {
         DataType dtype = getTensorType<T>();
         std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};
         bool use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;
@@ -355,8 +358,8 @@ public:
         sync_check_cuda_error();

         Tensor h_alibi_slopes = createTensor(MEMORY_CPU, dtype, {param.head_num});
-        Tensor h_alibi_bias = is_benchmark ? Tensor() :
-                                             createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
+        Tensor h_alibi_bias =
+            is_benchmark ? Tensor() : createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length});
         // The nearest power of 2 equal to / smaller than num_heads followed by HF's implementation.
         T*  alibi_slope_ptr = h_alibi_slopes.getPtr<T>();
         int num_heads_pow2  = utils::pow2_rounddown(param.head_num);
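The comment in this hunk describes utils::pow2_rounddown as returning the nearest power of two equal to or smaller than num_heads, matching HF's ALiBi handling of non-power-of-two head counts. A hedged sketch of such a helper, assuming a positive input; the actual implementation in the repo may differ:

    // Round a positive integer down to the nearest power of two (sketch only;
    // the real utils::pow2_rounddown may be implemented differently).
    inline int pow2_rounddown_sketch(int x)
    {
        int p = 1;
        while (p * 2 <= x) {
            p *= 2;
        }
        return p;
    }
    // pow2_rounddown_sketch(12) == 8, pow2_rounddown_sketch(16) == 16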
@@ -364,7 +367,8 @@ public:
             // The slope of linear bias of the attention head
             if (h < num_heads_pow2) {
                 alibi_slope_ptr[h] = static_cast<T>(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1));
-            } else {
+            }
+            else {
                 alibi_slope_ptr[h] = static_cast<T>(
                     powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1));
             }
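The loop shown here reproduces the ALiBi slope schedule used by HF's BLOOM implementation: powf(0.5f, powf(0.5f, log2f(n) - 3.f)) equals 2^(-8/n), so head h of the first num_heads_pow2 heads gets slope 2^(-8*(h+1)/n), and the remaining heads reuse the schedule for 2*num_heads_pow2 at odd positions. A small standalone check of the first branch (the head count of 8 is an arbitrary example):

    #include <cmath>
    #include <cstdio>

    // Recompute ALiBi slopes with the same expression as the test above and
    // compare against the closed form 2^(-8 * (h + 1) / n) (sketch).
    int main()
    {
        const int n = 8;  // num_heads_pow2, chosen arbitrarily for the example
        for (int h = 0; h < n; ++h) {
            float slope    = powf(powf(0.5f, powf(0.5f, log2f((float)n) - 3.f)), h + 1);
            float expected = exp2f(-8.0f * (h + 1) / n);
            printf("head %d: slope %f (closed form %f)\n", h, slope, expected);
        }
        return 0;
    }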
@@ -448,87 +452,106 @@ public:
 TYPED_TEST_SUITE(AttentionKernelTest, SupportTypes);

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt)
 {
     this->runTestMaskedSoftmax({1, 12, 12, 1, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2)
 {
     // q_length is not multiple of 4.
     this->runTestMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt)
 {
     this->runTestMaskedSoftmax({1, 12, 24, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2)
 {
     this->runTestMaskedSoftmax({1, 11, 24, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024)
 {
     this->runTestMaskedSoftmax({1, 12, 1024, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048)
 {
     this->runTestMaskedSoftmax({1, 12, 2048, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072)
 {
     this->runTestMaskedSoftmax({1, 12, 3072, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096)
 {
     this->runTestMaskedSoftmax({1, 12, 4096, 2, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024)
 {
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048)
 {
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096)
 {
     // Assume the bloom 176B model with 8 TP.
     this->runTestMaskedSoftmax({8, 4096, 4096, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1)
 {
     this->runTestAlibiMaskedSoftmax({1, 12, 12, 4, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2)
 {
     // q_length is not multiple of 4.
     this->runTestAlibiMaskedSoftmax({1, 11, 11, 4, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1)
 {
     this->runTestAlibiMaskedSoftmax({1, 12, 20, 4, 32, false, 0, false});
 }

 TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2)
 {
     // q_length is not multiple of 4.
     this->runTestAlibiMaskedSoftmax({1, 11, 20, 4, 32, false, 0, false});
 }

 // Tests for long sentence generation. Assume the bloom 176B model with 8 TP.
 TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024)
 {
     this->runTestAlibiMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048)
 {
     this->runTestAlibiMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072)
 {
     this->runTestAlibiMaskedSoftmax({8, 3072, 3072, 14, 128, false, 0, false, true}, true);
 }

 TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096)
 {
     this->runTestAlibiMaskedSoftmax({4, 4096, 4096, 14, 128, false, 0, false, true}, true);
 }