"...composable_kernel.git" did not exist on "7353ec0c25468d754ad5dd786e979a3bbade0a47"
Commit aee06365 authored by Po Yen, Chen

Remove redundant args setting

parent 18d235df
@@ -53,7 +53,9 @@ target_include_directories(${EXAMPLE_NAME}
     PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include # ignore compilation warnings in kernel implementation
 )
 target_link_libraries(${EXAMPLE_NAME} "${TORCH_LIBRARIES}")
-target_compile_definitions(${EXAMPLE_NAME} PRIVATE USE_ROCM)
+target_compile_definitions(${EXAMPLE_NAME}
+    PRIVATE USE_ROCM
+)
 target_compile_options(${EXAMPLE_NAME}
     PRIVATE ${TORCH_CXX_FLAGS}
 )
\ No newline at end of file
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include <torch/torch.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <hip/hip_runtime.h>
@@ -42,7 +43,6 @@ void paged_attention(
     const c10::optional<torch::Tensor>& fp8_out_scale,
     int64_t partition_size)
 {
     native::paged_attention_traits traits;
     traits.q_type = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
@@ -51,8 +51,6 @@ void paged_attention(
     native::paged_attention_args args;
-    args.head_size = query.size(2);
     args.num_seqs = query.size(0);
     args.num_heads = query.size(1);
     args.head_size = query.size(2);
@@ -88,10 +86,8 @@ void paged_attention(
     args.k_scale = k_scale;
     args.v_scale = v_scale;
-    hipStream_t stream = nullptr;
-    HIP_CHECK_ERROR(hipStreamCreate(&stream));
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     native::paged_attention(traits, args, stream);
-    HIP_CHECK_ERROR(hipStreamDestroy(stream));
 }
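
For context, a minimal, self-contained sketch of the stream handling adopted above, assuming a PyTorch extension built for ROCm with USE_ROCM defined; the function name launch_on_torch_stream and the commented-out my_kernel launch are hypothetical and not part of this commit. The point of the change is to reuse the stream PyTorch is already enqueueing work on, under a device guard for the query tensor's device, instead of creating and destroying a private HIP stream on every call.

// Sketch only: mirrors the pattern from the diff above; launch_on_torch_stream
// and my_kernel are illustrative names that do not exist in this repository.
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>    // at::cuda::OptionalCUDAGuard
#include <hip/hip_runtime.h>

void launch_on_torch_stream(const torch::Tensor& query)
{
    // Make sure subsequent work targets the device that owns `query`.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));

    // Reuse PyTorch's current stream; on ROCm builds the CUDA names are
    // mapped onto HIP, so this is effectively the active hipStream_t.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // A kernel launched on `stream` is ordered with the tensor ops PyTorch
    // has already enqueued, with no per-call stream create/destroy.
    // my_kernel<<<grid, block, 0, stream>>>(query.data_ptr());
    (void)stream; // silence unused-variable warning in this sketch
}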