Unverified Commit cfbbfb89 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Cleanup pytorch extensions (#1781)



* rm unused swizzle extensions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix swizzle
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Consistent namespaces and first refactor
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* format and lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* transformer_engine
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert accidental perm change
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f966d5f7
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
#include "extensions.h" #include "extensions.h"
namespace transformer_engine::pytorch {
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { ...@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous(); auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r ...@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r
} }
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa ...@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous(); auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so ...@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so
} }
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc ...@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grads_.contiguous(); auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -188,7 +179,6 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, ...@@ -188,7 +179,6 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
} }
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
using namespace transformer_engine::pytorch;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16), (input.scalar_type() == at::ScalarType::BFloat16),
...@@ -220,8 +210,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float ...@@ -220,8 +210,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float
at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_, at::Tensor softmax_results_,
float scale_factor) { float scale_factor) {
using namespace transformer_engine::pytorch;
auto output_grads = output_grad_.contiguous(); auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
...@@ -245,3 +233,5 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_ ...@@ -245,3 +233,5 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_
return output_grads; return output_grads;
} }
} // namespace transformer_engine::pytorch
...@@ -13,13 +13,12 @@ namespace transformer_engine::pytorch { ...@@ -13,13 +13,12 @@ namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list, std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list, std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list, std::vector<py::handle> quantizer_list, DType otype) {
transformer_engine::DType otype) {
init_extension(); init_extension();
std::vector<NVTETensor> nvte_tensor_input_list; std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list; std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list; std::vector<py::object> py_output_objects_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers; std::vector<TensorWrapper> tensor_wrappers;
if (output_list.has_value()) { if (output_list.has_value()) {
py_output_objects_list = output_list.value(); py_output_objects_list = output_list.value();
} }
...@@ -33,7 +32,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list, ...@@ -33,7 +32,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
auto input_tensor = makeTransformerEngineTensor(input_list[i]); auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape(); const NVTEShape input_shape = input_tensor.shape();
transformer_engine::TensorWrapper output_tensor; TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) { if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false; with_fused_kernel = false;
...@@ -80,8 +79,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list, ...@@ -80,8 +79,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
return py_output_objects_list; return py_output_objects_list;
} }
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
std::optional<at::Tensor> output) {
init_extension(); init_extension();
const auto dim = input.dim(); const auto dim = input.dim();
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h" #include "util.h"
#include "common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input, std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) { bool rowwise) {
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
...@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) { if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); scale_inv_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else { } else {
input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); input_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0,
scale_inv_shape); scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
} }
// Launch kernel // Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) { if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else { } else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
} scale_inv_shape);
return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
  using namespace transformer_engine::pytorch;

  // Swizzling only applies to 8-bit (FP8) data; reject anything wider.
  NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

  // Destination buffer for the swizzled scaling factors, same dtype/shape as
  // the source scales, on CUDA.
  const auto opts = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
  auto swizzled_scale_inv = at::empty_like(scale_inv, opts);

  void* src_scales_ptr = getDataPtr(scale_inv, 0);
  void* dst_scales_ptr = getDataPtr(swizzled_scale_inv, 0);

  // Both TE tensors alias the same FP8 data pointer; only the scale-inverse
  // buffers differ, so the kernel rewrites scales without touching data.
  const auto data_shape = getTensorShape(input);
  auto in_te = makeTransformerEngineTensor(input.data_ptr(), data_shape, DType::kFloat8E4M3,
                                           nullptr, nullptr, src_scales_ptr,
                                           getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
  auto out_te = makeTransformerEngineTensor(input.data_ptr(), data_shape, DType::kFloat8E4M3,
                                            nullptr, nullptr, dst_scales_ptr,
                                            getTensorShape(swizzled_scale_inv),
                                            NVTE_MXFP8_1D_SCALING);

  // Launch kernel on the current PyTorch CUDA stream.
  nvte_swizzle_scaling_factors(in_te.data(), out_te.data(), at::cuda::getCurrentCUDAStream());

  return swizzled_scale_inv;
}
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
// Return immediately if tensor is empty
if (scale_inv.numel() == 0) {
return swizzled_scale_inv;
} }
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv; return swizzled_scale_inv;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment