Unverified Commit 3a63b13d authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Fix incorrect sharding when only enable FSDP and Mem Misaligned in LN_BWD. (#379)



* [JAX] Fix incorrect sharding when only enable FSDP.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [JAX] Add WAR to memory misaligned issues of LN BWD.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [JAX] Reuse sm_arch for avoiding duplicate code.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [JAX] Support multiple sizes allocation in WorkspaceManager.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [JAX] Use template and ariadic arguments to improve multple sizes allocator.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent b8ba734e
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <string>
namespace transformer_engine { namespace transformer_engine {
......
...@@ -273,7 +273,7 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque ...@@ -273,7 +273,7 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32); auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);
size_t workspace_size = kCublasLtForwardWorkspaceSize; size_t workspace_size = kCublasLtForwardWorkspaceSize;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte); auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(), nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
...@@ -327,7 +327,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo ...@@ -327,7 +327,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) + dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) +
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype()); dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
void *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype()); TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
...@@ -412,13 +412,9 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl ...@@ -412,13 +412,9 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] * size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
dummy_dgamma_part_tensor.shape().data[1] * dummy_dgamma_part_tensor.shape().data[1] *
typeToSize(dummy_dgamma_part_tensor.dtype()); typeToSize(dummy_dgamma_part_tensor.dtype());
size_t total_workspace_size =
(workspace_size + barrier_size + dgamma_part_size + dbeta_part_size);
void *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size); auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace(
void *barrier = static_cast<char *>(workspace) + workspace_size; workspace_size, dgamma_part_size, dbeta_part_size, barrier_size);
void *dgamma_part = static_cast<char *>(barrier) + barrier_size;
void *dbeta_part = static_cast<char *>(dgamma_part) + dgamma_part_size;
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype()); TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
...@@ -811,7 +807,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -811,7 +807,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
output_s->data.dptr = softmax_aux; output_s->data.dptr = softmax_aux;
auto workspace_size = query_workspace_tensor.shape().data[0]; auto workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
...@@ -894,7 +890,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -894,7 +890,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
query_workspace_tensor.data(), stream); query_workspace_tensor.data(), stream);
size_t workspace_size = query_workspace_tensor.shape().data[0]; size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
...@@ -978,7 +974,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -978,7 +974,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype()); query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto rng_workspace_size = 2 * sizeof(int64_t); auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size; auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size); auto *workspace = WorkspaceManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
...@@ -1074,7 +1070,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1074,7 +1070,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
size_t workspace_size = size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype()); query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cassert> #include <cassert>
#include "common/util/cuda_runtime.h"
#include "utils.h" #include "utils.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -17,20 +18,7 @@ int GetCudaRuntimeVersion() { ...@@ -17,20 +18,7 @@ int GetCudaRuntimeVersion() {
return ver; return ver;
} }
int GetDeviceComputeCapability(int gpu_id) { int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
int max_num_gpu = 0;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&max_num_gpu));
assert(gpu_id < max_num_gpu);
int major = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, gpu_id));
int minor = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, gpu_id));
int gpu_arch = major * 10 + minor;
return gpu_arch;
}
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) { int64_t offset) {
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <cstdint> #include <cstdint>
#include <numeric>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
...@@ -26,25 +27,44 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q ...@@ -26,25 +27,44 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream); cudaStream_t stream);
class cublasLtMetaManager { class WorkspaceManager {
public: public:
static cublasLtMetaManager &Instance() { static WorkspaceManager &Instance() {
static thread_local cublasLtMetaManager instance; static thread_local WorkspaceManager instance;
return instance; return instance;
} }
cublasLtMetaManager() {} WorkspaceManager() {}
~cublasLtMetaManager() { Clear_(); } ~WorkspaceManager() { Clear_(); }
void *GetWorkspace(size_t size = 4194304) { void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size); ReallocateIfNeed_(size);
return workspace_; return workspace_;
} }
template <typename... Args>
inline auto GetWorkspace(Args... args) {
auto asks = std::array<size_t, sizeof...(Args)>{args...};
std::array<size_t, sizeof...(Args) + 1> offsets = {0};
std::array<void *, sizeof...(Args)> workspaces = {nullptr};
std::transform_inclusive_scan(
asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
[=](auto x) { return PadSize_(x); }, 0);
auto *workspace = GetWorkspace(offsets.back());
std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
[workspace](auto x) { return static_cast<char *>(workspace) + x; });
return workspaces;
}
private: private:
void *workspace_ = nullptr; void *workspace_ = nullptr;
size_t size_ = 0; size_t size_ = 0;
size_t PadSize_(size_t size) {
constexpr size_t alignment = 128;
return ((size + alignment - 1) / alignment) * alignment;
}
void Clear_() { void Clear_() {
if (workspace_ != nullptr) { if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_)); NVTE_CHECK_CUDA(cudaFree(workspace_));
...@@ -54,6 +74,7 @@ class cublasLtMetaManager { ...@@ -54,6 +74,7 @@ class cublasLtMetaManager {
} }
void Allocate_(size_t new_size) { void Allocate_(size_t new_size) {
new_size = PadSize_(new_size);
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size)); NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size; size_ = new_size;
} }
......
...@@ -138,7 +138,7 @@ def infer_major_sharding_type() -> MajorShardingType: ...@@ -138,7 +138,7 @@ def infer_major_sharding_type() -> MajorShardingType:
""" """
gsr = global_shard_resource() gsr = global_shard_resource()
resources = [gsr.dp_resource, gsr.tp_resource] resources = [gsr.dp_resource, gsr.tp_resource, gsr.fsdp_resource]
for idx, rs in enumerate(resources): for idx, rs in enumerate(resources):
try: try:
size, _ = _get_mesh_info(rs) size, _ = _get_mesh_info(rs)
...@@ -149,12 +149,15 @@ def infer_major_sharding_type() -> MajorShardingType: ...@@ -149,12 +149,15 @@ def infer_major_sharding_type() -> MajorShardingType:
dp_resource = resources[0] dp_resource = resources[0]
tp_resource = resources[1] tp_resource = resources[1]
fsdp_resource = resources[2]
if dp_resource is not None and \ def dp_enabled():
tp_resource is not None : return (fsdp_resource is not None) or (dp_resource is not None)
if dp_enabled() and tp_resource is not None:
return MajorShardingType.DPTP return MajorShardingType.DPTP
if dp_resource is not None: if dp_enabled():
return MajorShardingType.DP return MajorShardingType.DP
if tp_resource is not None: if tp_resource is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment