Commit 3253240a authored by xiabo
Browse files

对应官方最新版本 0.1.0，主要增加 Paged Attention

修改测试用例
parent a8ce8d27
......@@ -17,11 +17,13 @@
# Load the FetchContent module and declare the googletest dependency used by
# the unit tests in this directory.
include(FetchContent)
FetchContent_Declare(
googletest
# NOTE(review): a Git source (GIT_REPOSITORY/GIT_TAG) and a local URL source
# both appear here, and the Git pair is repeated below in commented-out form.
# This looks like the removed and added sides of a diff shown together --
# confirm which single source is meant to be active.
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.12.1
# NOTE(review): relative path to a vendored googletest tree; confirm what
# directory it is resolved against when FetchContent populates the content.
URL ../../../3rdparty/googletest-release-1.12.1
#GIT_REPOSITORY https://github.com/google/googletest.git
#GIT_TAG release-1.12.1
)
# Locate the CUDA installation; both lookups are marked REQUIRED.
# NOTE(review): the CUDAToolkit lookup appears once active and once commented
# out, next to find_package(CUDA). This looks like the removed and added sides
# of a diff shown together -- confirm which lookup is meant to remain. The
# CUDA::-prefixed names used by some link lines elsewhere in this file are
# presumably the imported targets of the CUDAToolkit lookup; verify before
# dropping it.
find_package(CUDAToolkit REQUIRED)
# find_package(CUDAToolkit REQUIRED)
find_package(CUDA REQUIRED)
if (NOT MSVC)
add_definitions(-DTORCH_CUDA=1)
......@@ -31,12 +33,14 @@ endif()
# Force the gtest_force_shared_crt cache entry to ON (FORCE overwrites any
# value already in the cache) before googletest is added to the build.
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
# Make the googletest content declared with FetchContent_Declare available.
FetchContent_MakeAvailable(googletest)
# Directory-scoped include path pointing at the vendored googletest headers.
# NOTE(review): the path is relative and uses the same 3rdparty location as the
# URL entry of the googletest declaration -- confirm it is still needed once
# the gtest target is linked.
include_directories(../../../3rdparty/googletest-release-1.12.1/googletest/include)
# Single `unittest` executable built from the listed CUDA test sources.
add_executable(unittest
test_attention_kernels.cu
test_logprob_kernels.cu
test_penalty_kernels.cu
test_sampling_kernels.cu
# NOTE(review): test_sampling_layer.cu is listed once active and once commented
# out. This looks like the removed and added sides of a diff shown together --
# confirm whether the source is meant to be built.
test_sampling_layer.cu
# test_sampling_layer.cu
test_tensor.cu)
# automatic discovery of unit tests
......@@ -46,38 +50,38 @@ target_compile_features(unittest PRIVATE cxx_std_14)
# Link dependencies for the `unittest` executable: one target_link_libraries
# call per test source, all adding to the same target with PUBLIC visibility.
# NOTE(review): in every call below, a line of CUDA::-prefixed library names is
# immediately followed by a line of plain names (cudart, curand, cublas). This
# looks like the removed and added sides of a diff shown together -- confirm
# which spelling is meant to remain. The plain-name lines carry no counterpart
# for CUDA::cublasLt.
# Sorted by alphabetical order of test name.
target_link_libraries( # Libs for test_attention_kernels
unittest PUBLIC
CUDA::cudart CUDA::curand
cudart curand
gpt_kernels gtest memory_utils tensor unfused_attention_kernels cuda_utils logger)
target_link_libraries( # Libs for test_logprob_kernels
unittest PUBLIC
CUDA::cudart
cudart
logprob_kernels memory_utils cuda_utils logger)
target_link_libraries( # Libs for test_penalty_kernels
unittest PUBLIC
CUDA::cublas CUDA::cublasLt CUDA::cudart
cublas cudart
sampling_penalty_kernels memory_utils cuda_utils logger)
target_link_libraries( # Libs for test_sampling_kernel
unittest PUBLIC
CUDA::cudart
cudart
sampling_topk_kernels sampling_topp_kernels memory_utils tensor cuda_utils logger)
# NOTE(review): test_sampling_layer.cu also appears commented out in the
# `unittest` source list, yet its libraries are still linked here -- confirm
# whether this call should stay.
target_link_libraries( # Libs for test_sampling_layer
unittest PUBLIC
CUDA::cublas CUDA::cublasLt CUDA::cudart
cublas cudart
cublasMMWrapper memory_utils
DynamicDecodeLayer TopKSamplingLayer TopPSamplingLayer tensor cuda_utils logger)
# Link dependencies of test_tensor into the `unittest` executable.
# NOTE(review): two argument lines follow the opening parenthesis and each one
# ends with ")". The second line therefore sits after an already-closed call
# and would not parse as written. It looks like the replacement for the first
# line (it adds the raw linker flag -lrocblas) -- confirm and keep only one.
target_link_libraries( # Libs for test_tensor
unittest PUBLIC tensor cuda_utils logger)
unittest PUBLIC -lrocblas tensor cuda_utils logger)
# Drop the directory-scoped -DTORCH_CUDA=1 definition (it is added earlier in
# this file under `if (NOT MSVC)`).
# NOTE(review): whether this also affects the `unittest` target defined above
# depends on how CMake applies directory-scoped definitions -- confirm.
remove_definitions(-DTORCH_CUDA=1)
# Standalone GEMM test executable built from test_gemm.cu.
add_executable(test_gemm test_gemm.cu)
# NOTE(review): two link calls for test_gemm follow. They differ only in the
# spelling of the CUDA libraries (CUDA:: names vs. plain names) and in the raw
# -lrocblas flag. This looks like the removed and added sides of a diff shown
# together -- confirm which call is meant to remain.
target_link_libraries(test_gemm PUBLIC CUDA::cublas CUDA::cudart CUDA::curand gemm cublasMMWrapper tensor cuda_utils logger)
target_link_libraries(test_gemm PUBLIC -lrocblas cublas cudart curand gemm cublasMMWrapper tensor cuda_utils logger)
# Standalone test executable for the GPT kernels, built from
# test_gpt_kernels.cu and linked against the project libraries listed below.
add_executable(test_gpt_kernels test_gpt_kernels.cu)
target_link_libraries(test_gpt_kernels PUBLIC
gpt_kernels memory_utils tensor cuda_utils logger)
# Test executable for the context attention layer, built from
# test_context_attention_layer.cu.
# NOTE(review): the target is defined below both in active form (CUDA::
# library names) and in commented-out form (plain cublas/cudart names). This
# looks like the removed and added sides of a diff shown together -- confirm
# whether the target is meant to be built at all.
add_executable(test_context_attention_layer test_context_attention_layer.cu)
target_link_libraries(test_context_attention_layer PUBLIC
Llama CUDA::cublas CUDA::cublasLt CUDA::cudart
unfused_attention_kernels
memory_utils tensor cublasMMWrapper cuda_utils logger)
#add_executable(test_context_attention_layer test_context_attention_layer.cu)
#target_link_libraries(test_context_attention_layer PUBLIC
# Llama cublas cudart
# unfused_attention_kernels
# memory_utils tensor cublasMMWrapper cuda_utils logger)
This diff is collapsed.
......@@ -446,10 +446,10 @@ TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessLargeK63)
this->runBatchTest({8, 4000, 1, 63, 1.0f, 8});
};
// Typed test of the TopKSamplingKernelTest fixture that calls runBatchTest
// with the parameter set {8, 4000, 1, 1024, 0.0f, 8}.
// NOTE(review): the test name suggests the fourth value (1024) is k; the
// parameter struct is not visible here -- confirm.
// NOTE(review): the same test is repeated right below in commented-out form.
// This looks like the removed and added sides of a diff shown together --
// confirm whether the test is meant to be disabled.
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessLargeK1024)
{
this->runBatchTest({8, 4000, 1, 1024, 0.0f, 8});
};
// TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessLargeK1024)
// {
// this->runBatchTest({8, 4000, 1, 1024, 0.0f, 8});
// };
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessTopKTopP)
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment