Commit 4086a4cc authored by wenjh's avatar wenjh
Browse files

Merge branch 'TE_ta_develop2.9' into 'develop_v2.9'

[DCU] Fix compilation error: unable to find nvte extract_seed_and_offset

See merge request dcutoolkit/deeplearing/TransformerEngine!64
parents 177291ac 121d9224
...@@ -432,7 +432,7 @@ def test_fuser_ops_with_userbuffers( ...@@ -432,7 +432,7 @@ def test_fuser_ops_with_userbuffers(
command = [] command = []
if tex.ubuf_built_with_mpi(): if tex.ubuf_built_with_mpi():
python_exe = pathlib.Path(sys.executable).resolve() python_exe = pathlib.Path(sys.executable).resolve()
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe)) command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--allow-run-as-root", python_exe))
else: else:
command.extend(("torchrun", f"--nproc_per_node={world_size}")) command.extend(("torchrun", f"--nproc_per_node={world_size}"))
......
...@@ -219,6 +219,7 @@ else() ...@@ -219,6 +219,7 @@ else()
fused_attn/flash_attn.cu fused_attn/flash_attn.cu
fused_attn/context_parallel.cu fused_attn/context_parallel.cu
fused_attn/kv_cache.cu fused_attn/kv_cache.cu
fused_attn/utils.cu
multi_tensor/adam.cu multi_tensor/adam.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu multi_tensor/l2norm.cu
......
...@@ -44,7 +44,17 @@ __device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) { ...@@ -44,7 +44,17 @@ __device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
// Bitwise logical op to get answer in MSBs // Bitwise logical op to get answer in MSBs
// Equivalent logic: result = (a == b) ? !result : b // Equivalent logic: result = (a == b) ? !result : b
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
result = (a == b) ? !result : b; // Use HIP vector types for byte-wise parallel comparison
union { uint32_t u32; uint8_t bytes[4]; } a_union, b_union;
a_union.u32 = a;
b_union.u32 = b;
uint32_t mask = 0;
mask |= (a_union.bytes[0] < b_union.bytes[0]) ? 0x80000000U : 0;
mask |= (a_union.bytes[1] < b_union.bytes[1]) ? 0x00800000U : 0;
mask |= (a_union.bytes[2] < b_union.bytes[2]) ? 0x00008000U : 0;
mask |= (a_union.bytes[3] < b_union.bytes[3]) ? 0x00000080U : 0;
result = mask;
#else #else
asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result)); asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
#endif #endif
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace fused_attn { namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine; using namespace transformer_engine;
// get matrix strides based on matrix type // get matrix strides based on matrix type
...@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud ...@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout; return hout;
} }
#endif
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr, __global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val, uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) { uint32_t offset_intragraph) {
......
...@@ -7,10 +7,11 @@ ...@@ -7,10 +7,11 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#endif
#include <cstdint> #include <cstdint>
#include <mutex> #include <mutex>
...@@ -19,7 +20,7 @@ ...@@ -19,7 +20,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace fused_attn { namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine; using namespace transformer_engine;
enum NVTE_QKV_Matrix { enum NVTE_QKV_Matrix {
...@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q ...@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
cudaStream_t stream); cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream); uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
} // namespace fused_attn } // namespace fused_attn
} // namespace transformer_engine } // namespace transformer_engine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment