Commit ba058648 authored by tabuchixiangcai3
Browse files

[DCU] Fix compilation failure: unable to find nvte_extract_seed_and_offset


Signed-off-by: Tangao <2205747538@qq.com>
parent 177291ac
...@@ -219,6 +219,7 @@ else() ...@@ -219,6 +219,7 @@ else()
fused_attn/flash_attn.cu fused_attn/flash_attn.cu
fused_attn/context_parallel.cu fused_attn/context_parallel.cu
fused_attn/kv_cache.cu fused_attn/kv_cache.cu
fused_attn/utils.cu
multi_tensor/adam.cu multi_tensor/adam.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu multi_tensor/l2norm.cu
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace fused_attn { namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine; using namespace transformer_engine;
// get matrix strides based on matrix type // get matrix strides based on matrix type
...@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud ...@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout; return hout;
} }
#endif
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr, __global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val, uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) { uint32_t offset_intragraph) {
......
...@@ -7,10 +7,11 @@ ...@@ -7,10 +7,11 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#endif
#include <cstdint> #include <cstdint>
#include <mutex> #include <mutex>
...@@ -19,7 +20,7 @@ ...@@ -19,7 +20,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace fused_attn { namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine; using namespace transformer_engine;
enum NVTE_QKV_Matrix { enum NVTE_QKV_Matrix {
...@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q ...@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
cudaStream_t stream); cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream); uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
} // namespace fused_attn } // namespace fused_attn
} // namespace transformer_engine } // namespace transformer_engine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment