Unverified Commit ed74d978 authored by Donggeun Yu's avatar Donggeun Yu Committed by GitHub
Browse files

DeformableDETR support bfloat16 (#29232)



* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype missmatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Support DeformableDETR with bfloat16

* Add test code

* Use AT_DISPATCH_FLOATING_TYPES_AND2

Use AT_DISPATCH_FLOATING_TYPES_AND2

* Update tests/models/deformable_detr/test_modeling_deformable_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/deformable_detr/test_modeling_deformable_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix not found require_torch_bf16 function

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent bcd23a54
......@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
......@@ -134,7 +134,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
......
......@@ -72,7 +72,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
......@@ -142,7 +142,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
......
......@@ -19,6 +19,14 @@ at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &attn_weight,
const int im2col_step);
at::Tensor ms_deform_attn_cuda_forward_bf16(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
......@@ -27,3 +35,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
......@@ -1758,7 +1758,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
valid_ratios = valid_ratios.float()
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
# Also provide spatial_shapes, level_start_index and valid_ratios
......
......@@ -26,6 +26,7 @@ from transformers.testing_utils import (
require_timm,
require_torch,
require_torch_accelerator,
require_torch_bf16,
require_vision,
slow,
torch_device,
......@@ -591,6 +592,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
output = model(**inputs)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
@require_torch_bf16
def create_and_check_model_bf16_forward(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config, torch_dtype=torch.bfloat16)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
output = model(**inputs)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
TOLERANCE = 1e-4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment