Unverified Commit b73b5b06 authored by Junhao's avatar Junhao Committed by GitHub
Browse files

Make microbatch optimization (DBO) work with general models (#37926)


Signed-off-by: default avatarJunhao Li <junhao@ubicloud.com>
parent 0f0e0389
...@@ -389,16 +389,20 @@ class UBatchWrapper: ...@@ -389,16 +389,20 @@ class UBatchWrapper:
inputs_embeds, inputs_embeds,
intermediate_tensors, intermediate_tensors,
): ):
sliced_input_ids = input_ids[tokens_slice] sliced_input_ids = input_ids[tokens_slice] if input_ids is not None else None
# if we are using mrope. Mrope adds an additional dimension to the # if we are using mrope. Mrope adds an additional dimension to the
# positions tensor # positions tensor
if positions.ndim == 2: if positions.ndim == 2:
sliced_positions = positions[:, tokens_slice] sliced_positions = positions[:, tokens_slice]
else: else:
sliced_positions = positions[tokens_slice] sliced_positions = positions[tokens_slice]
sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None sliced_inputs_embeds = (
inputs_embeds[tokens_slice] if inputs_embeds is not None else None
)
sliced_intermediate_tensors = ( sliced_intermediate_tensors = (
intermediate_tensors[tokens_slice] if intermediate_tensors else None intermediate_tensors[tokens_slice]
if intermediate_tensors is not None
else None
) )
return ( return (
...@@ -478,7 +482,7 @@ class UBatchWrapper: ...@@ -478,7 +482,7 @@ class UBatchWrapper:
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
) )
with self.sm_control: with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model) return self._capture_ubatches(ubatch_metadata, self.runnable)
elif ( elif (
num_tokens in self.cudagraphs num_tokens in self.cudagraphs
and cudagraph_runtime_mode is CUDAGraphMode.FULL and cudagraph_runtime_mode is CUDAGraphMode.FULL
...@@ -504,4 +508,4 @@ class UBatchWrapper: ...@@ -504,4 +508,4 @@ class UBatchWrapper:
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
) )
with self.sm_control: with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model) return self._run_ubatches(ubatch_metadata, self.runnable)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment