Commit aee06365 authored by Po Yen, Chen

Remove redundant args setting

parent 18d235df
@@ -53,7 +53,9 @@ target_include_directories(${EXAMPLE_NAME}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include # ignore compilation warnings in kernel implementation
)
target_link_libraries(${EXAMPLE_NAME} "${TORCH_LIBRARIES}")
target_compile_definitions(${EXAMPLE_NAME} PRIVATE USE_ROCM)
target_compile_definitions(${EXAMPLE_NAME}
PRIVATE USE_ROCM
)
target_compile_options(${EXAMPLE_NAME}
PRIVATE ${TORCH_CXX_FLAGS}
)
\ No newline at end of file
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_runtime.h>
@@ -42,7 +43,6 @@ void paged_attention(
const c10::optional<torch::Tensor>& fp8_out_scale,
int64_t partition_size)
{
native::paged_attention_traits traits;
traits.q_type = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
@@ -51,8 +51,6 @@ void paged_attention(
native::paged_attention_args args;
args.head_size = query.size(2);
args.num_seqs = query.size(0);
args.num_heads = query.size(1);
args.head_size = query.size(2);
@@ -88,10 +86,8 @@ void paged_attention(
args.k_scale = k_scale;
args.v_scale = v_scale;
hipStream_t stream = nullptr;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
native::paged_attention(traits, args, stream);
HIP_CHECK_ERROR(hipStreamDestroy(stream));
}
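
For context on the last hunk: the old code created and destroyed a private HIP stream around each call, while the new code switches to the device guard and current stream that PyTorch already manages for the input tensor. Below is a minimal sketch of that pattern; `launch_kernel` and `run_on_torch_stream` are hypothetical names, but the guard and stream calls are the same ATen ones used in the diff:

```cpp
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>     // at::cuda::OptionalCUDAGuard

// Hypothetical kernel entry point that takes a raw stream handle.
void launch_kernel(cudaStream_t stream);

void run_on_torch_stream(const torch::Tensor& query)
{
    // Activate the device that owns `query` for the duration of this scope.
    const at::cuda::OptionalCUDAGuard device_guard(at::device_of(query));

    // Reuse the stream PyTorch already associates with that device instead of
    // calling hipStreamCreate()/hipStreamDestroy() around every launch; this
    // also keeps the kernel ordered with other work PyTorch has queued there.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    launch_kernel(stream);
    // No explicit stream teardown: the stream is owned by PyTorch.
}
```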