Bw-bestperf / Qwen2.5-14B-Instruct_dcu-megatron · Commits · f356f546

Commit f356f546, authored Feb 04, 2026 by maming

Initial commit

Pipeline #3339 canceled with stages · Changes: 346 · Pipelines: 1
Showing 20 changed files with 4208 additions and 0 deletions (+4208, -0)
Megatron-Energon/src/megatron/energon/transforms/common.py (+15, -0)
Megatron-Energon/src/megatron/energon/transforms/custom.py (+12, -0)
Megatron-Energon/src/megatron/energon/transforms/mappers.py (+411, -0)
Megatron-Energon/src/megatron/energon/transforms/merge.py (+163, -0)
Megatron-Energon/src/megatron/energon/typed_converter.py (+1154, -0)
Megatron-Energon/src/megatron/energon/watchdog.py (+318, -0)
Megatron-Energon/src/megatron/energon/worker.py (+278, -0)
Megatron-Energon/src/megatron/energon/wrappers/__init__.py (+46, -0)
Megatron-Energon/src/megatron/energon/wrappers/_log_exception.py (+28, -0)
Megatron-Energon/src/megatron/energon/wrappers/base.py (+192, -0)
Megatron-Energon/src/megatron/energon/wrappers/batch_dataset.py (+286, -0)
Megatron-Energon/src/megatron/energon/wrappers/blend_dataset.py (+122, -0)
Megatron-Energon/src/megatron/energon/wrappers/buffer.py (+154, -0)
Megatron-Energon/src/megatron/energon/wrappers/concat_dataset.py (+51, -0)
Megatron-Energon/src/megatron/energon/wrappers/epochize_dataset.py (+122, -0)
Megatron-Energon/src/megatron/energon/wrappers/filter_dataset.py (+79, -0)
Megatron-Energon/src/megatron/energon/wrappers/gc_dataset.py (+130, -0)
Megatron-Energon/src/megatron/energon/wrappers/group_batch_dataset.py (+298, -0)
Megatron-Energon/src/megatron/energon/wrappers/iter_map_dataset.py (+224, -0)
Megatron-Energon/src/megatron/energon/wrappers/limit_dataset.py (+125, -0)
Megatron-Energon/src/megatron/energon/transforms/common.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from enum import Enum
from typing import TypeVar

import PIL.Image
import torch

TTransform = TypeVar("TTransform", bound=torch.nn.Module)


class Sampling(Enum):
    BILINEAR = PIL.Image.BILINEAR
    NEAREST = PIL.Image.NEAREST
Megatron-Energon/src/megatron/energon/transforms/custom.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from megatron.energon.transforms.mappers import TransformMapper


class CustomTransform(TransformMapper[torch.nn.Module]):
    """Abstract class for additional custom transforms.
    Inherit from this and override at least `apply_transform`.
    """
Megatron-Energon/src/megatron/energon/transforms/mappers.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import math
import numbers
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from torchvision import transforms as T

from megatron.energon.transforms.common import TTransform


class TransformMapper(ABC, Generic[TTransform]):
    source_type: ClassVar[Type[torch.nn.Module]]

    transform: TTransform

    def __init__(self, transform: TTransform):
        self.transform = transform

    def __call__(self, sample):
        return self.transform(sample)

    @staticmethod
    def translate(x: float, y: float) -> np.ndarray:
        m = np.eye(3, dtype=np.float64)
        m[0, 2] = x
        m[1, 2] = y
        return m

    @staticmethod
    def rotate(angle: float) -> np.ndarray:
        """Counter-clockwise rotation. Note that the Y-axis is point down."""
        m = np.eye(3, dtype=np.float64)
        m[:2, :2] = np.array([[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]])
        return m

    @staticmethod
    def scale(x: float, y: float) -> np.ndarray:
        m = np.eye(3, dtype=np.float64)
        m[0, 0] = x
        m[1, 1] = y
        return m

    @staticmethod
    def shear(x: float, y: float) -> np.ndarray:
        m = np.eye(3, dtype=np.float64)
        m[0, 1] = x
        m[1, 0] = y
        return m

    @staticmethod
    def hflip() -> np.ndarray:
        m = np.eye(3, dtype=np.float64)
        m[0, 0] = -1
        return m

    @staticmethod
    def vflip() -> np.ndarray:
        m = np.eye(3, dtype=np.float64)
        m[1, 1] = -1
        return m

    @abstractmethod
    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]: ...

    def fill(
        self,
    ) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
        return None

    def interpolation(self) -> Optional[T.InterpolationMode]:
        return None


class ResizeMapper(TransformMapper[T.Resize]):
    source_type = T.Resize

    def __init__(self, transform: T.Resize):
        super().__init__(transform)

    def _compute_resized_output_size(
        self, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
    ) -> List[int]:
        if len(size) == 1:
            # specified size only for the smallest edge
            h, w = image_size
            short, long = (w, h) if w <= h else (h, w)
            requested_new_short = size[0]

            new_short, new_long = requested_new_short, int(requested_new_short * long / short)

            if max_size is not None:
                if max_size <= requested_new_short:
                    raise ValueError(
                        f"max_size = {max_size} must be strictly greater than the requested "
                        f"size for the smaller edge size = {size}"
                    )
                if new_long > max_size:
                    new_short, new_long = int(max_size * new_short / new_long), max_size

            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        else:
            # specified both h and w
            new_w, new_h = size[1], size[0]
        return [new_h, new_w]

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Tuple[Any, ...]]:
        size = self.transform.size
        if isinstance(size, int):
            size = [size]
        h, w = self._compute_resized_output_size(dst_size, size, self.transform.max_size)
        matrix = self.scale(w / dst_size[1], h / dst_size[0]) @ matrix
        # matrix = self.scale((w - 1) / (dst_size[1] - 1), (h - 1) / (dst_size[0] - 1)) @ matrix
        # matrix = self.translate(0.25, 0.25) @ matrix
        # matrix = self.translate(0.1, 0) @ matrix
        dst_size = np.array((h, w), dtype=dst_size.dtype)
        # print(f"Resize s={size}")
        return matrix, dst_size, (self.source_type.__name__, size)

    def interpolation(self) -> Optional[T.InterpolationMode]:
        return self.transform.interpolation


class RandomResizedCropMapper(TransformMapper[T.RandomResizedCrop]):
    source_type = T.RandomResizedCrop

    def __init__(self, transform: T.RandomResizedCrop):
        super().__init__(transform)

    def get_params(self, size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """
        Gets the parameters for a random resized crop.
        This function is derived from T.RandomResizedCrop.get_params, but without requiring the
        input image (to determine the input size).

        Returns:
            Tuple of (top, left, height, width).
        """
        height, width = size
        area = height * width

        log_ratio = torch.log(torch.tensor(self.transform.ratio))
        for _ in range(10):
            target_area = (
                area * torch.empty(1).uniform_(self.transform.scale[0], self.transform.scale[1]).item()
            )
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.transform.ratio):
            w = width
            h = int(round(w / min(self.transform.ratio)))
        elif in_ratio > max(self.transform.ratio):
            h = height
            w = int(round(h * max(self.transform.ratio)))
        else:
            # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Tuple[Any, ...]]:
        top, left, height, width = self.get_params(dst_size)
        # print(
        #     "RandomResizedCrop", top, left, dst_size[0] - height - top, dst_size[1] - width - left
        # )
        # Crop to left, top, height, width
        matrix = self.translate(-left, -top) @ matrix
        dst_size = np.array([height, width], dtype=dst_size.dtype)
        # Resize to target size
        matrix = (
            self.scale(self.transform.size[1] / dst_size[1], self.transform.size[0] / dst_size[0]) @ matrix
        )
        dst_size = np.array(self.transform.size, dtype=dst_size.dtype)
        return matrix, dst_size, (self.source_type.__name__, (top, left, height, width))

    def interpolation(self) -> Optional[T.InterpolationMode]:
        return self.transform.interpolation


class RandomHorizontalFlipMapper(TransformMapper[T.RandomHorizontalFlip]):
    source_type = T.RandomHorizontalFlip

    def __init__(self, transform: T.RandomHorizontalFlip):
        super().__init__(transform)

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        do_flip = torch.rand(1) < self.transform.p
        if do_flip:
            matrix = self.hflip() @ matrix
            matrix = self.translate(dst_size[1], 0) @ matrix
            # print(f"RandomHorizontalFlip")
        return matrix, dst_size, (self.source_type.__name__, do_flip)


class RandomVerticalFlipMapper(TransformMapper[T.RandomVerticalFlip]):
    source_type = T.RandomVerticalFlip

    def __init__(self, transform: T.RandomVerticalFlip):
        super().__init__(transform)

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        do_flip = torch.rand(1) < self.transform.p
        if do_flip:
            matrix = self.vflip() @ matrix
            matrix = self.translate(0, dst_size[0]) @ matrix
            # print(f"RandomVerticalFlip")
        return matrix, dst_size, (self.source_type.__name__, do_flip)


class RandomRotationMapper(TransformMapper[T.RandomRotation]):
    source_type = T.RandomRotation

    def __init__(self, transform: T.RandomRotation):
        super().__init__(transform)

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        assert self.transform.center is None, "Only centered rotation is supported"
        degrees = self.transform.get_params(self.transform.degrees)
        rads = degrees * np.pi / 180
        # print(f"Rotate deg={degrees}")
        orig_size = dst_size
        if self.transform.expand:
            # Compute size of rotated rectangle
            w = np.abs(np.sin(rads)) * dst_size[0] + np.abs(np.cos(rads)) * dst_size[1]
            h = np.abs(np.sin(rads)) * dst_size[1] + np.abs(np.cos(rads)) * dst_size[0]
            # Round in the same way as PIL does
            rounded_w = np.ceil(orig_size[1] / 2 + w / 2) - np.floor(orig_size[1] / 2 - w / 2)
            rounded_h = np.ceil(orig_size[0] / 2 + h / 2) - np.floor(orig_size[0] / 2 - h / 2)
            # New size is h, w
            dst_size = np.array([int(rounded_h), int(rounded_w)], dtype=dst_size.dtype)
        matrix = (
            self.translate(dst_size[1] / 2, dst_size[0] / 2)
            @ self.rotate(rads)
            @ self.translate(-orig_size[1] / 2, -orig_size[0] / 2)
            @ matrix
        )
        return matrix, dst_size, (self.source_type.__name__, degrees)

    def fill(
        self,
    ) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
        return self.transform.fill

    def interpolation(self) -> Optional[T.InterpolationMode]:
        return self.transform.interpolation


class RandomCropMapper(TransformMapper[T.RandomCrop]):
    source_type = T.RandomCrop

    def __init__(self, transform: T.RandomCrop):
        super().__init__(transform)

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        th, tw = self.transform.size  # Target height and width

        # pad the width if needed
        if self.transform.pad_if_needed and dst_size[1] < tw:
            padding = tw - dst_size[1]
            # Pad this much on both left and right
            matrix = self.translate(padding, 0) @ matrix
            dst_size[1] += 2 * padding
        # pad the height if needed
        if self.transform.pad_if_needed and dst_size[0] < th:
            padding = th - dst_size[0]
            # Pad this much on both top and bottom
            matrix = self.translate(0, padding) @ matrix
            dst_size[0] += 2 * padding

        h, w = dst_size

        if h < th or w < tw:
            raise ValueError(
                f"Required crop size {(th, tw)} is larger than input image size {(h, w)}"
            )

        if w == tw and h == th:
            # No need to crop if we're at the target size already
            i = 0
            j = 0
        else:
            i = torch.randint(0, h - th + 1, size=(1,)).item()  # Offset y
            j = torch.randint(0, w - tw + 1, size=(1,)).item()  # Offset x

        matrix = self.translate(-j, -i) @ matrix
        if self.transform.pad_if_needed:
            dst_size = np.array((th, tw), dtype=dst_size.dtype)
        else:
            dst_size = np.array((min(th, dst_size[0]), min(tw, dst_size[1])), dtype=dst_size.dtype)
        # print(f"RandomCrop t=[{dx}, {dy}], s={dst_size}")
        return matrix, dst_size, (self.source_type.__name__, (j, i, th, tw))

    def fill(
        self,
    ) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
        return self.transform.fill


class RandomPerspectiveMapper(TransformMapper[T.RandomPerspective]):
    source_type = T.RandomPerspective

    def __init__(self, transform: T.RandomPerspective):
        super().__init__(transform)

    @staticmethod
    def compute_homography(
        startpoints: List[Tuple[float, float]], endpoints: List[Tuple[float, float]]
    ) -> np.ndarray:
        assert len(startpoints) == len(endpoints) == 4
        a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)

        for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
            a_matrix[2 * i, :] = torch.tensor(
                [p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]
            )
            a_matrix[2 * i + 1, :] = torch.tensor(
                [0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]
            )

        b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
        res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

        m = np.eye(3, dtype=np.float32)
        m[0, :] = res[:3]
        m[1, :] = res[3:6]
        m[2, :2] = res[6:]

        return m

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        assert self.transform.fill == 0, "Only zero fill is supported"
        startpoints = None
        endpoints = None
        if torch.rand(1) <= self.transform.p:
            startpoints, endpoints = self.transform.get_params(
                dst_size[1], dst_size[0], self.transform.distortion_scale
            )
            # print(
            #     f"Perspective ds={self.transform.distortion_scale}: sp={startpoints} -> ep={endpoints}"
            # )
            matrix = self.compute_homography(endpoints, startpoints) @ matrix
        return matrix, dst_size, (self.source_type.__name__, startpoints, endpoints)

    def fill(
        self,
    ) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
        return self.transform.fill

    def interpolation(self) -> Optional[T.InterpolationMode]:
        return self.transform.interpolation


class CenterCropMapper(TransformMapper[T.CenterCrop]):
    source_type = T.CenterCrop

    def __init__(self, transform: T.CenterCrop):
        super().__init__(transform)

    def apply_transform(
        self, matrix: np.ndarray, dst_size: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Any]:
        if isinstance(self.transform.size, numbers.Number):
            th, tw = int(self.transform.size), int(self.transform.size)
        elif isinstance(self.transform.size, (tuple, list)) and len(self.transform.size) == 1:
            th, tw = self.transform.size[0], self.transform.size[0]
        else:
            th, tw = self.transform.size
        shift_y = round(float(th - dst_size[0]) / 2)
        shift_x = round(float(tw - dst_size[1]) / 2)
        matrix = self.translate(shift_x, shift_y) @ matrix
        dst_size = np.array((th, tw), dtype=dst_size.dtype)
        # print(f"CenterCrop t=[{dx}, {dy}], s={dst_size}")
        return matrix, dst_size, (self.source_type.__name__, (shift_y, shift_x, th, tw))
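
Not part of the commit: a small numpy-only sketch of the matrix convention the mappers use, built solely from the static helpers above. Each mapper left-multiplies its own 3x3 homogeneous matrix onto the accumulated forward mapping, with the Y axis pointing down; the sample point and offsets below are arbitrary illustration values.

import numpy as np

from megatron.energon.transforms.mappers import TransformMapper

# Accumulate a forward mapping: first shift by (10, 5), then rotate 90 degrees
# counter-clockwise (later transforms are left-multiplied, as in the mappers).
m = TransformMapper.rotate(np.pi / 2) @ TransformMapper.translate(10, 5)

# Map a source pixel (x=1, y=2) through the combined homogeneous matrix.
src = np.array([1.0, 2.0, 1.0])
dst = m @ src
print(dst[:2] / dst[2])  # -> [7., -11.]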
Megatron-Energon/src/megatron/energon/transforms/merge.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Type

import numpy as np
import PIL.Image
import torch
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode, pil_modes_mapping

from megatron.energon.transforms.custom import CustomTransform
from megatron.energon.transforms.mappers import TransformMapper


class MergeTransform(torch.nn.Module):
    transforms: List[TransformMapper]

    def __init__(self, transforms: List[torch.nn.Module]):
        super().__init__()
        self.transforms = [self._transform_to_mapper(t) for t in transforms]

        # Get common interpolation and fill values
        interpolation = None
        interpolation_t = None
        fill = None
        fill_t = None
        for t in self.transforms:
            t_fill = t.fill()
            if t_fill is not None:
                if fill is None:
                    fill = t_fill
                    fill_t = t
                if fill is not None and t_fill != fill:
                    raise ValueError(
                        f"Fill values are not equal: {fill} ({fill_t}) != {t_fill} ({t})"
                    )
            t_interpolation = t.interpolation()
            if t_interpolation is not None:
                if interpolation is None:
                    interpolation = t_interpolation
                    interpolation_t = t
                if interpolation is not None and t_interpolation != interpolation:
                    raise ValueError(
                        f"Interpolation values are not equal: {interpolation} ({interpolation_t}) "
                        f"!= {t_interpolation} ({t})"
                    )
        self.interpolation = InterpolationMode.BILINEAR if interpolation is None else interpolation
        self.fill_value = fill

    def _transform_to_mapper(self, transform: torch.nn.Module) -> Type[TransformMapper]:
        """Given a transform object, instantiate the corresponding mapper.
        This also handles objects of derived transform classes."""
        if isinstance(transform, CustomTransform):
            # Custom transforms can be used as-is, they provide the apply_transform method
            return transform

        for m in TransformMapper.__subclasses__():
            if isinstance(transform, m.source_type):
                return m(transform)  # Instantiate

        raise ValueError(f"Unsupported transform type {type(transform)}")

    def forward(self, x):
        matrix = np.eye(3, dtype=np.float64)
        if isinstance(x, PIL.Image.Image):
            dst_size = np.array((x.height, x.width), dtype=np.int64)
        else:
            dst_size = np.array(x.shape[-2:], dtype=np.int64)

        all_params = []

        for transform in self.transforms:
            matrix, dst_size, params = transform.apply_transform(matrix, dst_size)
            all_params.append(params)

        if isinstance(x, PIL.Image.Image):
            try:
                interpolation = pil_modes_mapping[self.interpolation]
            except KeyError:
                raise NotImplementedError(f"interpolation: {self.interpolation}")
            # Invert matrix for backward mapping
            matrix = np.linalg.inv(matrix)
            # Scale matrix
            matrix /= matrix[2, 2]
            if self.fill_value is None:
                fill_color = None
            elif isinstance(self.fill_value, (int, float)):
                fill_color = (self.fill_value,) * len(x.getbands())
            else:
                fill_color = self.fill_value
            if np.allclose(matrix[2, :2], [0, 0]):
                # print("PIL Affine")
                return x.transform(
                    tuple(dst_size[::-1]),
                    PIL.Image.AFFINE,
                    matrix.flatten()[:6],
                    interpolation,
                    fillcolor=fill_color,
                )
            else:
                # print("PIL Perspective")
                return x.transform(
                    tuple(dst_size[::-1]),
                    PIL.Image.PERSPECTIVE,
                    matrix.flatten()[:8],
                    interpolation,
                    fillcolor=fill_color,
                )
        elif isinstance(x, torch.Tensor):
            print("torch affine")
            if self.interpolation == T.InterpolationMode.NEAREST:
                interpolation = "nearest"
            elif self.interpolation == T.InterpolationMode.BILINEAR:
                interpolation = "bilinear"
            elif self.interpolation == T.InterpolationMode.BICUBIC:
                interpolation = "bicubic"
            else:
                raise NotImplementedError(f"interpolation: {self.interpolation}")
            if self.fill_value is not None and self.fill_value != 0:
                raise NotImplementedError(f"Fill value {self.fill_value} is not supported for torch")
            # Normalize to [-1, 1] range
            matrix = (
                TransformMapper.translate(-1, -1)
                @ TransformMapper.scale(2 / dst_size[1], 2 / dst_size[0])
                @ matrix
                @ TransformMapper.scale(x.shape[-1] / 2, x.shape[-2] / 2)
                @ TransformMapper.translate(1, 1)
            )
            matrix = np.linalg.inv(matrix)
            if np.allclose(matrix[2, :2], [0, 0]):
                grid = torch.nn.functional.affine_grid(
                    torch.as_tensor(matrix[None, :2, :], dtype=torch.float32),
                    torch.Size((1, 3, *dst_size)),
                )
            else:
                xs = torch.linspace(-1, 1, dst_size[1], dtype=torch.float32)
                ys = torch.linspace(-1, 1, dst_size[0], dtype=torch.float32)
                zs = torch.ones((1,), dtype=torch.float32)
                # shape: (2<x,y,1>, W, H)
                grid = torch.stack(torch.meshgrid([xs, ys, zs], indexing="ij"))[..., 0]
                # shape: (H, W, 2<x,y,1>)
                grid = grid.permute(2, 1, 0)
                # shape: (H, W, 3<x,y,w>, 1)
                grid = (
                    torch.as_tensor(matrix, dtype=torch.float32)[None, None, ...] @ grid[..., None]
                )
                # shape: (H, W, 2<x,y>)
                grid = grid[:, :, :2, 0] / grid[:, :, 2:3, 0]
                # shape: (1, H, W, 2<x,y>)
                grid = grid[None, ...]
            return torch.nn.functional.grid_sample(
                x[None, ...], grid, interpolation, padding_mode="zeros", align_corners=False
            )[0, ...]
        else:
            raise NotImplementedError()
            # TODO: Needs implementation and testing
            import cv2

            return cv2.warpAffine(x, matrix[:2], tuple(dst_size), flags=cv2.INTER_LINEAR)
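
Not part of the commit: a hedged usage sketch based only on the constructor and `forward` signatures above. `MergeTransform` is meant to stand in for a chain of geometric torchvision transforms, collapsing them into a single warp so the image is resampled only once; the interpolation is passed explicitly so the mappers agree on one mode, and the image path is a placeholder.

import PIL.Image
from torchvision import transforms as T

from megatron.energon.transforms.merge import MergeTransform

# Combine several torchvision transforms into one affine/perspective warp.
merged = MergeTransform([
    T.RandomResizedCrop((224, 224), interpolation=T.InterpolationMode.BILINEAR),
    T.RandomRotation(10, interpolation=T.InterpolationMode.BILINEAR),
    T.RandomHorizontalFlip(p=0.5),
])

img = PIL.Image.open("example.jpg")  # placeholder path
out = merged(img)  # single PIL warp to 224x224 instead of one resample per transform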
Megatron-Energon/src/megatron/energon/typed_converter.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

"""Provides methods for converting typed objects to json objects and vice versa."""

import dataclasses
import importlib
import inspect
import json
import re
import typing
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath


class JsonValueError(ValueError):
    """Internal error class, for better error messages."""

    def __init__(
        self, msg: str, expected_type: Type, value: Any, path: str, stage: Tuple[int, ...]
    ) -> None:
        """New error"""
        super().__init__(msg)
        self.expected_type = expected_type
        self.value = value
        self.path = path
        self.stage = stage


class _MissingValue:
    """Internal class representing a missing value (i.e. a field which has not been specified,
    opposed to explicit `None`) for better error messages."""

    def __repr__(self) -> str:
        return "missing value"


_missing_value = _MissingValue()

TType = TypeVar("TType")


def _check_instance_type(cls, inst_type: Type) -> bool:
    """Check if a class is an instance of a type."""
    if inst_type is None:
        return True
    elif typing.get_origin(inst_type) is not None:
        org = typing.get_origin(inst_type)
        if org is Union:
            for check_type in typing.get_args(inst_type):
                if _check_instance_type(cls, check_type):
                    return True
        elif isinstance(org, type) and issubclass(cls, org):
            return True
        return False
    elif inst_type is Any:
        return True
    else:
        return not issubclass(inst_type, type) or issubclass(cls, inst_type)


@edataclass
class JsonParser:
    strict: bool = True

    def _resolve_object(
        self,
        module_name: str,
        object_name: str,
        cls: Type[TType],
        is_type: bool,
        is_callable: bool,
        is_instantiating_class: bool,
        is_calling_function: bool,
    ) -> TType:
        """Resolve an object reference to the object."""
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError:
            try_energon_module = importlib.import_module("megatron.energon", package=None)
            if hasattr(try_energon_module, object_name):
                module = try_energon_module
            else:
                raise
        try:
            return getattr(module, object_name)
        except AttributeError:
            raise ModuleNotFoundError(f"Object {object_name} not found in {module_name}")

    def raw_to_instance(
        self,
        kwargs: dict,
        inst_type: Type[TType],
        _path: str = "root",
        _stage: Tuple[int, ...] = (),
    ) -> TType:
        """
        Try to import and instantiate a class from a dict with "__module__" and "__class__"/"__function__" keys.

        Args:
            kwargs: The dict to parse
            inst_type: Expected return type, used if type is not specified in the kwargs
            strict: If true, don't allow additional attributes
            _path: (internal for recursive call) The path to the object being converted from the root
            _stage: (internal for recursive call) Numbers representing the position of the current
                object being converted from the root

        Returns:
            Instantiated class
        """
        kwargs = kwargs.copy()
        module_name = kwargs.pop("__module__", None)
        # Check if this is a type of Type[...] or just a class. Type[...] will return the class instead
        # of instantiating it.
        is_type = typing.get_origin(inst_type) is type
        is_callable = typing.get_origin(inst_type) is typing.get_origin(Callable)
        is_calling_function = False
        is_instantiating_class = False
        if is_type:
            inst_type = typing.get_args(inst_type)[0]
            object_name = kwargs.pop("__class__", None)
            if module_name is None or object_name is None:
                raise JsonValueError(
                    f"Expected __module__ and __class__ for Type[{inst_type}], got {kwargs}",
                    inst_type, (module_name, object_name), _path, _stage,
                )
        elif is_callable:
            object_name = kwargs.pop("__function__", None)
            if module_name is None or object_name is None:
                raise JsonValueError(
                    f"Expected __module__ and __function__ for {inst_type}, got {kwargs}",
                    inst_type, (module_name, object_name), _path, _stage,
                )
        else:
            if "__class__" in kwargs:
                object_name = kwargs.pop("__class__", None)
                is_instantiating_class = True
                is_calling_function = False
            elif "__function__" in kwargs:
                object_name = kwargs.pop("__function__", None)
                is_instantiating_class = False
                is_calling_function = True
            # Else case: It's a plain type, and nothing was passed, use the default cls
        if module_name is None or object_name is None:
            cls = inst_type
        else:
            cls = self._resolve_object(
                module_name,
                object_name,
                inst_type,
                is_type,
                is_callable,
                is_instantiating_class,
                is_calling_function,
            )
        if is_type:
            if isinstance(inst_type, type) and (
                not isinstance(cls, type) or not issubclass(cls, inst_type)
            ):
                raise JsonValueError(
                    f"Expected Type[{inst_type}], got {cls}", inst_type, cls, _path, _stage
                )
        elif is_callable:
            if not callable(cls):
                raise JsonValueError(f"Expected a callable, got {cls}", inst_type, cls, _path, _stage)
        elif is_instantiating_class:
            if not isinstance(cls, type) or not _check_instance_type(cls, inst_type):
                raise JsonValueError(f"Expected {inst_type}, got {cls}", inst_type, cls, _path, _stage)
        else:
            assert is_calling_function
            if not callable(cls):
                raise JsonValueError(f"Expected {inst_type}, got {cls}", inst_type, cls, _path, _stage)
        if is_type or is_callable:
            inst = cls
        else:
            # Do not assert the other cases, we fallback to the passed cls
            inst = self.safe_call_function(kwargs, cls, allow_imports=True)
            assert not isinstance(cls, type) or _check_instance_type(type(inst), inst_type), (
                f"Expected {inst_type}, got {cls}"
            )
        return inst

    def raw_to_typed(  # noqa: C901
        self,
        raw_data: Union[dict, list, str, int, bool, float, None],
        inst_type: Type[TType],
        allow_imports: bool = False,
        _path: str = "root",
        _stage: Tuple[int, ...] = (),
    ) -> TType:
        """
        Converts raw data (i.e. dicts, lists and primitives) to typed objects (like
        `NamedTuple` or `dataclasses.dataclass`). Validates that python typing matches.

        Usage::

            class MyNamedTuple(NamedTuple):
                x: int
                y: str

            assert raw_to_typed({'x': int, 'y': "foo"}, MyNamedTuple) == MyNamedTuple(x=5, y="foo")

        Args:
            raw_data: The raw (e.g. json) data to be made as `inst_type`
            inst_type: The type to return
            allow_imports: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit
                instantiation of types
            _path: (internal for recursive call) The path to the object being converted from the root
            _stage: (internal for recursive call) Numbers representing the position of the current
                object being converted from the root

        Returns:
            The input data as `inst_type`.
        """
        type_name = getattr(inst_type, "__name__", repr(inst_type))
        if raw_data is _missing_value:
            raise JsonValueError(
                f"Missing value at {_path}", inst_type, raw_data, _path, _stage,
            )
        elif inst_type in (str, int, float, bool, None, type(None)):
            # Literal types or missing data
            if not isinstance(raw_data, inst_type) and not (
                isinstance(raw_data, int) and inst_type is float
            ):
                raise JsonValueError(
                    f"Type does not match, expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return raw_data
        elif inst_type is Any:
            if (
                allow_imports
                and isinstance(raw_data, dict)
                and "__module__" in raw_data
                and ("__class__" in raw_data or "__function__" in raw_data)
            ):
                return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage)
            # Any
            return raw_data
        elif typing.get_origin(inst_type) is Literal:
            # Literal[value[, ...]]
            values = typing.get_args(inst_type)
            if raw_data not in values:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return raw_data
        elif typing.get_origin(inst_type) is Union:
            # Union[union_types[0], union_types[1], ...]
            union_types = typing.get_args(inst_type)
            if None in union_types:
                # Fast Optional path
                if raw_data is None:
                    return None
            best_inner_error: Optional[JsonValueError] = None
            inner_exceptions = []
            for subtype in union_types:
                try:
                    return self.raw_to_typed(
                        raw_data,
                        subtype,
                        allow_imports,
                        f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}",
                        _stage + (1,),
                    )
                except JsonValueError as err:
                    if best_inner_error is None or len(err.stage) > len(best_inner_error.stage):
                        best_inner_error = err
                        inner_exceptions.clear()
                        inner_exceptions.append(err)
                    elif len(err.stage) == len(best_inner_error.stage):
                        inner_exceptions.append(err)
                    continue
            if len(inner_exceptions) > 0:
                cur_exc = inner_exceptions[0]
                for next_exc in inner_exceptions[1:]:
                    try:
                        raise next_exc from cur_exc
                    except JsonValueError as e:
                        cur_exc = e
                raise cur_exc
            else:
                raise JsonValueError(
                    f"Expected {inst_type} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
        elif (
            isinstance(inst_type, type)
            and issubclass(inst_type, tuple)
            and hasattr(inst_type, "__annotations__")
        ):
            # class MyClass(NamedTuple): ...
            if not isinstance(raw_data, dict):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            if getattr(inst_type, "__dash_keys__", "False"):
                raw_data = {key.replace("-", "_"): val for key, val in raw_data.items()}
            defaults = getattr(inst_type, "_field_defaults", {})
            kwargs = {
                field_name: self.raw_to_typed(
                    raw_data.get(field_name, defaults.get(field_name, _missing_value)),
                    field_type,
                    allow_imports,
                    f"{_path} -> {type_name}:{field_name}",
                    _stage + (idx,),
                )
                for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items())
            }
            if self.strict and not set(raw_data).issubset(inst_type.__annotations__):
                raise JsonValueError(
                    f"Additional attributes for {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            try:
                return inst_type(**kwargs)
            except BaseException:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
        elif dataclasses.is_dataclass(inst_type):
            # dataclass
            if not isinstance(raw_data, dict):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            kwargs = {
                field.name: self.raw_to_typed(
                    raw_data.get(
                        field.name,
                        (
                            (
                                _missing_value
                                if field.default_factory is dataclasses.MISSING
                                else field.default_factory()
                            )
                            if field.default is dataclasses.MISSING
                            else field.default
                        ),
                    ),
                    field.type,
                    allow_imports,
                    f"{_path} -> {type_name}:{field.name}",
                    _stage + (idx,),
                )
                for idx, field in enumerate(dataclasses.fields(inst_type))
                if field.init
            }
            if self.strict and not set(raw_data).issubset(
                field.name for field in dataclasses.fields(inst_type) if field.init
            ):
                raise JsonValueError(
                    f"Additional attributes for {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            try:
                return inst_type(**kwargs)
            except BaseException:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
        elif typing.get_origin(inst_type) is list:
            # List[inner_type]
            (inner_type,) = typing.get_args(inst_type)
            if not isinstance(raw_data, list):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return [
                self.raw_to_typed(
                    val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
                )
                for idx, val in enumerate(raw_data)
            ]
        elif typing.get_origin(inst_type) is set:
            # Set[inner_type]
            (inner_type,) = typing.get_args(inst_type)
            if not isinstance(raw_data, list):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            res = set(
                self.raw_to_typed(
                    val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
                )
                for idx, val in enumerate(raw_data)
            )
            if len(res) != len(raw_data):
                raise JsonValueError(
                    f"Duplicate element at {_path}", inst_type, raw_data, _path, _stage,
                )
            return res
        elif typing.get_origin(inst_type) is tuple:
            # Tuple[inner_types[0], inner_types[1], ...] or Tuple[inner_types[0], Ellipsis/...]
            inner_types = typing.get_args(inst_type)
            if not isinstance(raw_data, list):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            if len(inner_types) == 2 and inner_types[1] is Ellipsis:
                # Tuple of arbitrary length, all elements same type
                # Tuple[inner_types[0], Ellipsis/...]
                return tuple(
                    self.raw_to_typed(
                        val, inner_types[0], allow_imports, f"{_path} -> {idx}", _stage + (idx,)
                    )
                    for idx, val in enumerate(raw_data)
                )
            else:
                # Fixed size/typed tuple
                # Tuple[inner_types[0], inner_types[1], ...]
                if len(raw_data) != len(inner_types):
                    raise JsonValueError(
                        f"Expected {type_name} at {_path}, got {raw_data!r}",
                        inst_type, raw_data, _path, _stage,
                    )
                return [
                    self.raw_to_typed(
                        val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
                    )
                    for idx, (val, inner_type) in enumerate(zip(raw_data, inner_types))
                ]
        elif typing.get_origin(inst_type) is dict:
            # Dict[str, value_type]
            key_type, value_type = typing.get_args(inst_type)
            assert key_type is str
            if not isinstance(raw_data, dict):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return {
                key: self.raw_to_typed(
                    val, value_type, allow_imports, f"{_path} -> {key!r}", _stage + (idx,)
                )
                for idx, (key, val) in enumerate(raw_data.items())
            }
        elif inst_type in (dict, list):
            # dict, list (no subtyping)
            if not isinstance(raw_data, inst_type):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return raw_data
        elif inst_type is EPath:
            if isinstance(raw_data, str):
                return EPath(raw_data)
            elif not isinstance(raw_data, EPath):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {raw_data!r}",
                    inst_type, raw_data, _path, _stage,
                )
            return raw_data
        elif (
            allow_imports
            and isinstance(raw_data, dict)
            and "__module__" in raw_data
            and ("__class__" in raw_data or "__function__" in raw_data)
        ):
            return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage)
        else:
            return raw_data

    def safe_call_function(
        self,
        raw_data: Union[dict, list, str, int, bool, float, None],
        fn: Callable[..., TType],
        allow_imports: bool = False,
    ) -> TType:
        """
        Converts raw data (i.e. dicts, lists and primitives) to typed call arguments.
        Validates that python typing matches.

        Usage::

            def fn(arg1: float, arg2: MyType, arg3) -> Any:
                assert isinstance(arg1, float)
                assert isinstance(arg2, MyType)

            fn(3.141, MyType(), None)

        Args:
            raw_data: The raw (e.g. json) data to be made as `inst_type`
            fn: The function to call with the converted data
            strict: If true, don't allow additional attributes
            allow_imports: If true, allow instantiating objects by specifying __module__ and __class__/__function__.

        Returns:
            The return value of `fn`
        """
        parameters = list(inspect.signature(fn).parameters.items())
        if inspect.isclass(fn):
            init_sig = getattr(fn, "__init__", None)
            if init_sig is not None:
                parameters = list(inspect.signature(init_sig).parameters.items())[1:]
        args = []
        kwargs = {}
        if isinstance(raw_data, dict):
            unused_args = raw_data.copy()
            for idx, (key, param) in enumerate(parameters):
                t = Any if param.annotation is inspect.Parameter.empty else param.annotation
                if param.kind in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                ):
                    if param.default is inspect.Parameter.empty and key not in unused_args:
                        raise ValueError(f"Missing required argument {key!r} for {fn}")
                    kwargs[key] = self.raw_to_typed(
                        unused_args.pop(key, param.default),
                        t,
                        allow_imports,
                        _path=key,
                        _stage=(idx,),
                    )
                elif param.kind == inspect.Parameter.VAR_KEYWORD:
                    for arg_key, arg_val in unused_args.items():
                        kwargs[arg_key] = self.raw_to_typed(
                            arg_val, t, allow_imports, _path=key, _stage=(idx,)
                        )
                    unused_args.clear()
                elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                    # No way to pass positional arguments
                    pass
                elif param.kind == inspect.Parameter.POSITIONAL_ONLY:
                    # No way to pass positional arguments
                    raise RuntimeError(f"Unsupported positional only argument {key!r}")
                else:
                    raise RuntimeError(f"Unknown parameter kind {param.kind!r}")
            if self.strict and len(unused_args) > 0:
                raise ValueError(f"Unexpected arguments: {unused_args!r}")
        elif isinstance(raw_data, list):
            unused_args = raw_data.copy()
            for idx, (key, param) in enumerate(parameters):
                t = Any if param.annotation is inspect.Parameter.empty else param.annotation
                if param.kind == inspect.Parameter.POSITIONAL_ONLY:
                    if param.default is inspect.Parameter.empty and len(unused_args) == 0:
                        raise ValueError(
                            f"Missing required positional-only argument {key!r} at index {idx}"
                        )
                    args.append(
                        self.raw_to_typed(
                            unused_args.pop(), t, allow_imports, _path=key, _stage=(idx,)
                        )
                    )
                elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                    if param.default is inspect.Parameter.empty and len(unused_args) == 0:
                        raise ValueError(
                            f"Missing required positional argument {key!r} at index {idx}"
                        )
                    if len(unused_args) == 0:
                        arg_val = param.default
                    else:
                        arg_val = unused_args.pop()
                    args.append(
                        self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,))
                    )
                elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                    for arg_val in unused_args:
                        args.append(
                            self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,))
                        )
                    unused_args.clear()
                elif param.kind == inspect.Parameter.VAR_KEYWORD:
                    # No way to pass keyword arguments
                    pass
                elif param.kind == inspect.Parameter.KEYWORD_ONLY:
                    raise RuntimeError(f"Unsupported keyword-only argument {key!r}")
                else:
                    raise RuntimeError(f"Unknown parameter kind {param.kind!r}")
            if self.strict and len(unused_args) > 0:
                raise ValueError(f"Unexpected arguments: {unused_args!r}")
        else:
            raise ValueError(
                f"Cannot call function with raw data of type {type(raw_data)!r}, require list or dict"
            )
        return fn(*args, **kwargs)

    def override(  # noqa: C901
        self,
        value: TType,
        overrides: Any,
        inst_type: Optional[Type[TType]] = None,
        allow_imports: bool = False,
        _path: str = "root",
        _stage: Tuple[int, ...] = (),
    ) -> TType:
        """
        Allows overriding values of a typed object using environment config.
        Allows overriding single config variables, or whole objects.

        Examples::

            class MyNamedTuple(NamedTuple):
                x: int
                y: str

            class MyNested(NamedTuple):
                nested: MyNamedTuple

            assert override(
                MyNested(nested=MyNamedTuple(x=42, y="foo")),
                {'nested.x': 5},
            ) == MyNested(nested=MyNamedTuple(x=5, y="foo"))

            assert override(
                MyNested(nested=MyNamedTuple(x=42, y="foo")),
                {'nested': '{"x": 5, "y": "bar"}'},
            ) == MyNested(nested=MyNamedTuple(x=5, y="bar"))

        Args:
            value: The base value to override.
            overrides: The overrides to apply
            strict: If true, no additional keys are allowed
            inst_type: If given, validate against this base type instead of the type of `value`.
            allow_imports: If true, allow instantiating types with dicts of __module__ and __class__/__function__.
            _path: Internal: The path to the current value.
            _stage: Internal: The current stage of the override.

        Returns:
            Same type as the input object (or `inst_type` if set), copied and updated from the
            overrides.
        """
        if inst_type is None:
            inst_type = type(value)
        type_name = getattr(inst_type, "__name__", repr(inst_type))
        if inst_type in (str, int, float, bool, None, type(None)):
            # Literal types
            if inst_type in (None, type(None)) and overrides == "None":
                overrides = None
            elif inst_type is bool and overrides in ("True", "true", "1", "False", "false", "0"):
                overrides = overrides in ("True", "true", "1")
            elif inst_type in (int, float) and isinstance(overrides, str):
                overrides = inst_type(overrides)
            if not isinstance(overrides, inst_type) and not (
                isinstance(overrides, int) and inst_type is float
            ):
                raise JsonValueError(
                    f"Type does not match, expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            return overrides
        elif inst_type is Any:
            # Any
            if isinstance(overrides, str):
                if overrides.isnumeric():
                    return int(overrides)
                elif overrides == "True":
                    return True
                elif overrides == "False":
                    return True
                return overrides
            if isinstance(value, (dict, list, tuple)):
                # Merge with dict, list, str
                return self.override(value, overrides, type(value), allow_imports, _path, _stage)
            raise JsonValueError(
                f"Expected {type_name} at {_path}, got {overrides!r}",
                inst_type, overrides, _path, _stage,
            )
        elif typing.get_origin(inst_type) is Literal:
            # Literal[value]
            (value,) = typing.get_args(inst_type)
            if value != overrides:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            return value
        elif typing.get_origin(inst_type) is Union:
            # Union[union_types[0], union_types[1], ...]
            union_types = typing.get_args(inst_type)
            if isinstance(overrides, str):
                for subtype in union_types:
                    if subtype is None and overrides == "None":
                        return None
                    elif subtype is bool:
                        if overrides == "True":
                            return True
                        elif overrides == "False":
                            return False
                    elif subtype is int and overrides.strip().isnumeric():
                        return int(overrides)
                    elif subtype is str:
                        return overrides
                    elif subtype is float and float_pattern.fullmatch(overrides):
                        return float(overrides)
                if overrides.lstrip().startswith("{") or overrides.lstrip().startswith("["):
                    overrides = json.loads(overrides)
                return self.raw_to_typed(overrides, inst_type, allow_imports, _path, _stage)
            for subtype in union_types:
                if _isinstance_deep(value, subtype):
                    return self.override(
                        value,
                        overrides,
                        subtype,
                        allow_imports,
                        f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}",
                        _stage + (1,),
                    )
            raise JsonValueError(
                f"Expected {type_name} at {_path}, existing is {value!r} which is invalid",
                inst_type, value, _path, _stage,
            )
        elif (
            isinstance(inst_type, type)
            and issubclass(inst_type, tuple)
            and hasattr(inst_type, "__annotations__")
        ):
            # class MyClass(NamedTuple): ...
            if not isinstance(overrides, (dict, str)):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            if isinstance(overrides, str):
                return self.raw_to_typed(json.loads(overrides), inst_type, allow_imports, _path, _stage)
            local_overrides = _split_dict_keys(overrides)
            if getattr(inst_type, "__dash_keys__", "False"):
                local_overrides = {
                    key.replace("-", "_"): val for key, val in local_overrides.items()
                }
            kwargs = {
                field_name: (
                    self.override(
                        getattr(value, field_name),
                        local_overrides.pop(field_name),
                        field_type,
                        allow_imports,
                        f"{_path} -> {type_name}:{field_name}",
                        _stage + (idx,),
                    )
                    if field_name in local_overrides
                    else getattr(value, field_name)
                )
                for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items())
            }
            if self.strict and len(local_overrides) != 0:
                raise JsonValueError(
                    f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at "
                    f"{_path}",
                    inst_type, overrides, _path, _stage,
                )
            try:
                return inst_type(**kwargs)
            except BaseException:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
        elif dataclasses.is_dataclass(inst_type):
            # dataclass
            if not isinstance(overrides, (dict, str)):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            if isinstance(overrides, str):
                return self.raw_to_typed(json.loads(overrides), inst_type, allow_imports, _path, _stage)
            local_overrides = _split_dict_keys(overrides)
            if getattr(inst_type, "__dash_keys__", "False"):
                local_overrides = {
                    key.replace("-", "_"): val for key, val in local_overrides.items()
                }
            kwargs = {
                field.name: (
                    self.override(
                        getattr(value, field.name),
                        local_overrides.pop(field.name),
                        field.type,
                        allow_imports,
                        f"{_path} -> {type_name}:{field.name}",
                        _stage + (idx,),
                    )
                    if field.name in local_overrides
                    else getattr(value, field.name)
                )
                for idx, field in enumerate(dataclasses.fields(inst_type))
                if field.init
            }
            if self.strict and len(local_overrides) != 0:
                raise JsonValueError(
                    f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at "
                    f"{_path}",
                    inst_type, overrides, _path, _stage,
                )
            try:
                return inst_type(**kwargs)
            except BaseException:
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
        elif (
            typing.get_origin(inst_type) is list
            or typing.get_origin(inst_type) is tuple
            or inst_type in (list, tuple)
        ):
            # List[inner_type] or Tuple[inner_type, Ellipsis] or
            # Tuple[inner_type[0], inner_type[1], ...]
            if inst_type is list:
                inner_type = Any
                inner_types = []
                cls = list
            elif inst_type is tuple:
                inner_type = Any
                inner_types = []
                cls = tuple
            elif typing.get_origin(inst_type) is list:
                (inner_type,) = typing.get_args(inst_type)
                inner_types = []
                cls = list
            else:
                inner_types = typing.get_args(inst_type)
                if len(inner_types) == 2 and inner_types[1] is Ellipsis:
                    inner_type = inner_types[0]
                else:
                    inner_type = None
                cls = tuple
            if not isinstance(overrides, (dict, str)):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            if isinstance(overrides, str):
                return self.raw_to_typed(json.loads(overrides), inst_type, allow_imports, _path, _stage)
            local_overrides = _split_dict_keys(overrides)
            if not all(key.isnumeric() for key in local_overrides.keys()):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}, expected integer keys",
                    inst_type, overrides, _path, _stage,
                )
            local_overrides_int = {int(key): value for key, value in local_overrides.items()}
            new_max_idx = max(local_overrides_int.keys())
            original_max_idx = len(value)
            if inner_type is None and new_max_idx >= len(inner_types):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}, index {new_max_idx} out of "
                    f"bounds",
                    inst_type, overrides, _path, _stage,
                )
            for i in range(original_max_idx, new_max_idx):
                if i not in local_overrides_int:
                    raise JsonValueError(
                        f"Expected {type_name} at {_path}, got {overrides!r}, missing value for index "
                        f"{i}",
                        inst_type, overrides, _path, _stage,
                    )
            return cls(
                (
                    self.override(
                        value[idx],
                        local_overrides_int[idx],
                        inner_type,
                        allow_imports,
                        f"{_path} -> {idx}",
                        _stage + (idx,),
                    )
                    if idx in local_overrides_int
                    else value[idx]
                )
                for idx in range(max(new_max_idx + 1, original_max_idx))
            )
        elif typing.get_origin(inst_type) is dict or inst_type is dict:
            # Dict[str, value_type]
            if inst_type is dict:
                value_type = Any
            else:
                key_type, value_type = typing.get_args(inst_type)
                assert key_type is str
            if not isinstance(overrides, (dict, str)):
                raise JsonValueError(
                    f"Expected {type_name} at {_path}, got {overrides!r}",
                    inst_type, overrides, _path, _stage,
                )
            if isinstance(overrides, str):
                return self.raw_to_typed(json.loads(overrides), inst_type, allow_imports, _path, _stage)
            local_overrides = _split_dict_keys(overrides)
            if getattr(inst_type, "__dash_keys__", "False"):
                local_overrides = {
                    key.replace("-", "_"): val for key, val in local_overrides.items()
                }
            res = {
                key: (
                    self.override(
                        subvalue,
                        local_overrides.pop(key),
                        value_type,
                        allow_imports,
                        f"{_path} -> {type_name}:{key!r}",
                        _stage + (idx,),
                    )
                    if key in local_overrides
                    else subvalue
                )
                for idx, (key, subvalue) in value.items()
            }
            for key, val in local_overrides.items():
                if not isinstance(val, str):
                    raise JsonValueError(
                        f"Expected new {type_name} at {_path} -> {type_name}:{key!r}, got {val!r}",
                        inst_type, overrides, _path, _stage,
                    )
                res[key] = self.raw_to_typed(
                    json.loads(val),
                    value_type,
                    allow_imports,
                    f"{_path} -> {type_name}:{key!r}",
                    _stage + (len(res),),
                )
            return res
        else:
            raise RuntimeError(f"Unknown type {inst_type}")


def to_json_object(obj: Any) -> Any:
    """
    Converts the given object to a json object.

    Args:
        obj: The object to convert

    Returns:
        The json-like object.
    """
    if isinstance(obj, (str, int, float, bool, type(None))):
        # Literal types
        return obj
    elif isinstance(obj, tuple) and hasattr(obj, "__annotations__"):
        # class MyClass(NamedTuple): ...
        return {
            field_name: to_json_object(getattr(obj, field_name))
            for field_name in obj.__annotations__.keys()
        }
    elif dataclasses.is_dataclass(obj):
        # dataclass
        return {
            field.name: to_json_object(getattr(obj, field.name))
            for field in dataclasses.fields(obj)
            if field.init
        }
    elif isinstance(obj, (list, tuple)):
        return [to_json_object(val) for val in obj]
    elif isinstance(obj, dict):
        return {key: to_json_object(val) for key, val in obj.items()}
    else:
        raise RuntimeError(f"Unknown type {type(obj)}")


float_pattern = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?")


def _split_dict_keys(dct: Dict[str, Any]) -> Dict[str, Any]:
    """Splits the given dict keys by first '.' to subdicts."""
    res = {}
    for key, value in dct.items():
        if "." in key:
            outer_key, _, inner_key = key.partition(".")
            if outer_key in res:
                if not isinstance(res[outer_key], dict):
                    raise ValueError(f"Cannot combine {outer_key!r} with {res!r}")
                res[outer_key][inner_key] = value
            else:
                res[outer_key] = {inner_key: value}
        else:
            if key in res:
                raise ValueError(f"Cannot combine {key!r} with {res!r}")
            res[key] = value
    return res


def _isinstance_deep(val: Any, tp_chk: Type) -> bool:
    """Verifies if the given value is an instance of the tp_chk, allowing for typing extensions."""
    if tp_chk is Any:
        return True
    elif typing.get_origin(tp_chk) is Literal:
        (value,) = typing.get_args(val)
        return val == value
    elif typing.get_origin(tp_chk) is list:
        (inner_type,) = typing.get_args(val)
        return isinstance(val, list) and all(_isinstance_deep(v, inner_type) for v in val)
    elif typing.get_origin(tp_chk) is tuple:
        inner_types = typing.get_args(val)
        if len(inner_types) == 2 and inner_types[1] == Ellipsis:
            return isinstance(val, tuple) and all(_isinstance_deep(v, inner_types[0]) for v in val)
        else:
            return (
                isinstance(val, tuple)
                and len(val) == len(inner_types)
                and all(_isinstance_deep(v, inner_type) for v, inner_type in zip(val, inner_types))
            )
    elif typing.get_origin(tp_chk) is dict:
        key_type, value_type = typing.get_args(val)
        return isinstance(val, dict) and all(
            _isinstance_deep(k, key_type) and _isinstance_deep(v, value_type)
            for k, v in val.items()
        )
    else:
        return isinstance(val, tp_chk)
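
Not part of the commit: a hedged usage sketch of `JsonParser.raw_to_typed`, following the pattern in its docstring but with a dataclass target. The `LoaderConfig` type and its fields are made up for illustration, and the sketch assumes `edataclass` supports keyword construction like a standard dataclass.

from dataclasses import dataclass
from typing import Optional

from megatron.energon.typed_converter import JsonParser


@dataclass
class LoaderConfig:  # hypothetical target type, not part of energon
    path: str
    shuffle: bool = True
    seed: Optional[int] = None


parser = JsonParser(strict=True)
# Missing keys fall back to field defaults; unexpected keys raise in strict mode.
cfg = parser.raw_to_typed({"path": "/data/train", "seed": 42}, LoaderConfig)
assert cfg == LoaderConfig(path="/data/train", shuffle=True, seed=42)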
Megatron-Energon/src/megatron/energon/watchdog.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
import linecache
import os
import sys
import threading
import time
import traceback
from time import perf_counter
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensorBase

# For the watch_iter type
T = TypeVar("T")

# Maximum length of a single object string to print.
PRINT_LOCAL_MAX_LENGTH = 250


class Watchdog:
    """
    A watchdog timer that:
      - can be 'enabled' or 'disabled' by presence/absence of a deadline,
      - resets automatically when 'enable()' is called,
      - can be used as a context manager,
      - can wrap an iterator to watch only the time for 'next()' calls,
      - attempts a two-phase shutdown on callback error:
        1) sys.exit(1) for graceful,
        2) if still alive after 10s, os._exit(1).
    """

    def __init__(
        self,
        timeout: float,
        initial_timeout: Optional[float] = None,
        callback: Optional[Callable[[], None]] = None,
        dump_stacks: bool = True,
        enabled: bool = True,
    ) -> None:
        """
        Args:
            timeout: Number of seconds before the watchdog fires if not reset/disabled.
            initial_timeout: Number of seconds before the watchdog fires in the first iteration.
            callback: Optional function to call upon timeout.
            dump_stacks: If True, print full stack traces for all threads on timeout (except
                watchdog's own thread).
            enabled: If False, watchdog starts disabled until enable() is called.
        """
        self._timeout = timeout
        self._initial_timeout = initial_timeout
        self._callback = callback
        self._dump_stacks = dump_stacks
        self._is_first_iteration = True

        # If _deadline is None, the watchdog is disabled.
        # Otherwise, _deadline = time.time() + _timeout if enabled.
        if enabled:
            self._deadline: Optional[float] = perf_counter() + self._get_next_timeout()
        else:
            self._deadline = None

        self._stop = False  # signals permanent shutdown (finish)

        # Condition variable to manage state changes
        self._cv = threading.Condition()

        # Background thread (daemon) that monitors timeouts
        self._worker_thread = threading.Thread(target=self._worker, daemon=True)
        self._worker_thread.start()

    def _get_next_timeout(self) -> float:
        if self._is_first_iteration:
            self._is_first_iteration = False
            return self._initial_timeout if self._initial_timeout is not None else self._timeout
        else:
            return self._timeout

    def _worker(self) -> None:
        """
        Background thread that periodically checks if the watchdog has expired.
        Once it times out or is told to stop, it exits.
        """
        while True:
            with self._cv:
                if self._stop:
                    # finish() was called; end the worker.
                    return
                if self._deadline is None:
                    # Disabled; no deadline. Just wait a bit, then re-check.
                    self._cv.wait(timeout=1.0)
                    continue

                remaining = self._deadline - perf_counter()
                if remaining <= 0:
                    # We have timed out
                    self._on_timeout()
                    return
                else:
                    # Wait until either the deadline or a state change
                    self._cv.wait(timeout=remaining)

    def _on_timeout(self) -> None:
        """
        Called exactly once if the watchdog times out.
        1) Optionally dumps stacks,
        2) Calls user callback,
        3) If callback raises an error,
           - print traceback,
           - sys.exit(1),
           - fallback to os._exit(1) after 10s if process not terminated.
        """
        watchdog_thread_id = threading.get_ident()

        # 1) Dump stacks if requested
        if self._dump_stacks:
            print("Watchdog triggered: Dumping thread stacks")
            self._print_all_thread_stacks(skip_thread_id=watchdog_thread_id)

        # 2) Call user callback
        if self._callback:
            try:
                self._callback()
            except Exception:
                # Print the traceback
                traceback.print_exc()

                # Start a background kill-switch after 10 seconds
                def force_exit_after_delay() -> None:
                    time.sleep(10)
                    os._exit(1)

                killer = threading.Thread(target=force_exit_after_delay, daemon=True)
                killer.start()

                # Attempt graceful shutdown
                sys.exit(1)

    def _print_all_thread_stacks(self, skip_thread_id: Optional[int] = None) -> None:
        """
        Dump stacks of all threads in a style reminiscent of py-spy, from
        innermost (current) to outermost. Skip the watchdog's own thread if given.

        Args:
            skip_thread_id: If given, skip this thread's stack.
        """
        frames = sys._current_frames()  # thread_id -> frame

        # We gather known threads to print their names
        all_threads = {t.ident: t for t in threading.enumerate()}

        for thread_id, frame in frames.items():
            if skip_thread_id is not None and thread_id == skip_thread_id:
                continue

            thread = all_threads.get(thread_id)
            thread_name = thread.name if thread else f"Unknown-{thread_id}"
            print(f'Thread {thread_id}: "{thread_name}"')

            # Build the stack from current (innermost) to outermost
            stack_frames = []
            f = frame
            while f is not None:
                stack_frames.append(f)
                f = f.f_back

            for fr in stack_frames:
                code = fr.f_code
                func_name = code.co_name
                filename = code.co_filename
                lineno = fr.f_lineno
                print(f"    {func_name} ({filename}:{lineno})")

                # Attempt to read the actual line of source
                line = linecache.getline(filename, lineno).rstrip()
                if line:
                    print(f"      > {line}")

                # Show arguments and locals
                arg_info = inspect.getargvalues(fr)
                arg_names = arg_info.args
                varargs = arg_info.varargs
                varkw = arg_info.keywords
                local_vars = arg_info.locals

                # Separate out the arguments
                arg_dict = {}
                for arg in arg_names:
                    if arg in local_vars:
                        arg_dict[arg] = local_vars[arg]
                if varargs and varargs in local_vars:
                    arg_dict["*" + varargs] = local_vars[varargs]
                if varkw and varkw in local_vars:
                    arg_dict["**" + varkw] = local_vars[varkw]

                if arg_dict:
                    print("      Arguments:")
                    for k, v in arg_dict.items():
                        print(f"        {k}: {repr_short(v)}")

                other_locals = {k: v for k, v in local_vars.items() if k not in arg_dict}
                if other_locals:
                    print("      Locals:")
                    for k, v in other_locals.items():
                        print(f"        {k}: {repr_short(v)}")

        print(flush=True)

    def reset(self) -> None:
        """
        Reset the watchdog timer (push out deadline by `timeout` seconds),
        but only if currently enabled (i.e., _deadline is not None).
        """
        with self._cv:
            if self._deadline is not None:
                self._deadline = perf_counter() + self._timeout
                self._cv.notify()

    def enable(self) -> None:
        """
        Enable (or re-enable) the watchdog. Always resets the deadline to
        `time.time() + timeout`.
        """
        with self._cv:
            self._deadline = perf_counter() + self._get_next_timeout()
            self._cv.notify()

    def disable(self) -> None:
        """
        Disable the watchdog (no timeout will fire until re-enabled).
        """
        with self._cv:
            self._deadline = None
            self._cv.notify()

    def finish(self) -> None:
        """
        Permanently stop the watchdog thread and disarm the timer.
        After calling finish(), you cannot re-enable this watchdog.
        """
        with self._cv:
            self._stop = True
            self._cv.notify()
        self._worker_thread.join()

    def __enter__(self) -> "Watchdog":
        # If currently disabled, calling enable() will also reset the timer
        if self._deadline is None:
            self.enable()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # End the watchdog on context exit
        self.finish()

    def watch_iter(self, iterable: Iterable[T]) -> Iterator[T]:
        """
        Wrap an iterable so that each 'next()' call is watched by the watchdog,
        but the time in between iterations is not watched. Usage:

            wd = Watchdog(timeout=3, enabled=False)
            for item in wd.watch_iter(generator()):
                # processing item not timed by the watchdog
                pass

        This pattern:
          - enable()  -> sets/extends deadline
          - next(...) -> measured portion
          - disable() -> stops timer

        Args:
            iterable: The iterable to wrap and watch.

        Returns:
            An iterator that wraps the input iterable and watches for timeouts.
        """
        try:
            self.enable()
            for item in iterable:
                self.disable()
                yield item
                self.enable()
        finally:
            self.disable()


def repr_short(obj: Any) -> str:
    """
    Return a short repr of an object.
    """
    if isinstance(obj, torch.Tensor):
        if isinstance(obj, ShardedTensorBase) or obj.is_cuda:
            return "<CUDA tensor>"
    s = repr(obj)
    if len(s) > PRINT_LOCAL_MAX_LENGTH:
        s = s[: PRINT_LOCAL_MAX_LENGTH // 2] + "..." + s[-PRINT_LOCAL_MAX_LENGTH // 2 :]
    return s


if __name__ == "__main__":
    # Example usage
    def my_callback() -> None:
        print("Watchdog timed out in callback.")
        # Demonstrate an error
        raise ValueError("Example error from callback.")

    print("Simple usage example:")
    wd = Watchdog(timeout=2, callback=my_callback, enabled=True)

    print("Sleeping 3s so the watchdog times out.")
    time.sleep(30)

    # Because we never reset or finish, the watchdog should fire and
    # forcibly exit, after printing the traceback and stack dumps.
    print("You won't see this line if the watchdog fired first.")
Megatron-Energon/src/megatron/energon/worker.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import hashlib
import json
import multiprocessing
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar

import torch.distributed
import torch.utils.data

from megatron.energon.cache import CachePool

__all__ = ("WorkerConfig",)

T = TypeVar("T")


@dataclass(slots=True, kw_only=True, eq=False)
class WorkerConfig:
    """
    Provides information about the current worker and the global configuration. This gives each
    data parallel rank its proper config. Every `rank` (up to `world_size-1`) must be used.
    If set wrong, the datasets might yield the same data or data might be missing, as data
    is split over the data parallel ranks with this config!
    You may set the same rank, if you need multiple ranks to retrieve the same data.
    """

    #: The data parallel rank/id of the current process.
    rank: int
    #: The total number of data parallel processes.
    world_size: int
    #: The number of workers per rank. May be 0 to disable worker processes.
    num_workers: int

    #: If not using all ranks for data parallel, set this to the corresponding group.
    data_parallel_group: Optional[torch.distributed.ProcessGroup] = None

    #: The id offset of the current worker. e.g. the worker may live as `worker_info.id=0`, but
    # actually yield samples for id=1 (i.e. worker_id_offset=1). Required to support restoring the
    # worker state if the last emitted sample was not for worker_id=0. Required by SavableDataLoader
    # to restore the worker state. Is only set to nonzero within a worker process.
    worker_id_offset: ClassVar[int] = 0

    #: The following seed_offset is used at two points in the code.
    # 1. The seed_offset in the worker_config that is passed to the dataset initialization is used
    #    to set the seed for the dataset shuffling and shuffled blending (all code that uses
    #    WorkerRng).
    # 2. The worker_config passed to the data loader initialization is used to set the seed for the
    #    torch, numpy and random libraries. This does not affect the dataset shuffling, but only the
    #    user code (e.g. code in TaskEncoder).
    seed_offset: int = 0

    #: The path to the debug file for the current worker. Should contain "{worker_id}" and "{pid}"
    # to separate the workers.
    worker_debug_path: Optional[str] = None
    #: Log level for worker logging.
    worker_log_level: int = 0

    #: The opened file for the current worker. Should not be set from outside.
    _worker_debug_file: Optional[TextIO] = None
    #: worker_id of the opened worker debug file
    _worker_debug_file_worker_id: Optional[int] = None

    #: The current sample index within the current iterating worker
    _sample_index_stack: ClassVar[Optional[List[int]]] = None

    #: The current worker config within the current iterating worker
    active_worker_config: ClassVar[Optional["WorkerConfig"]] = None

    #: The global rank override for the worker. Required for restoring samples.
    _worker_override_global_rank: ClassVar[Optional[List[int]]] = None

    #: The current cache pool for the worker.
    _cache_pool: "ClassVar[Optional[CachePool]]" = None

    def worker_activate(
        self,
        sample_index: int,
        override_global_rank: Optional[int] = None,
        cache_pool: "Optional[CachePool]" = None,
    ):
        """Activates the worker config for the current worker and sets it as actively iterating.
        Must be called before next() call on the datasets."""
        assert WorkerConfig.active_worker_config is None
        WorkerConfig._sample_index_stack = [sample_index]
        WorkerConfig.active_worker_config = self
        WorkerConfig._worker_override_global_rank = override_global_rank
        WorkerConfig._cache_pool = cache_pool

    def worker_push_sample_index(self, sample_index: int):
        """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets
        before calling inners."""
        assert WorkerConfig.active_worker_config is not None
        WorkerConfig._sample_index_stack.append(sample_index)

    def worker_pop_sample_index(self):
        """Pops the last sample index from the sample index stack. Should be called by wrapping
        datasets after calling inners."""
        assert WorkerConfig.active_worker_config is not None
        return WorkerConfig._sample_index_stack.pop()

    def worker_deactivate(self):
        """Deactivates the worker config for the current worker and deactivates it for iterating.
        Must be called after next() call on the datasets."""
        if WorkerConfig.active_worker_config is not None:
            assert len(WorkerConfig._sample_index_stack) == 1, (
                f"Sample index stack not empty: {WorkerConfig._sample_index_stack}"
            )
            WorkerConfig._sample_index_stack = None
            WorkerConfig.active_worker_config = None
            WorkerConfig._worker_override_global_rank = None

    @property
    def active_worker_sample_index(self) -> int:
        """Returns the current sample index for the actively iterating worker."""
        # Internal sample index is for the local worker. If using multiple workers per rank, this
        # must be multiplied by the number of workers and offset by the local worker index.
        return (
            WorkerConfig._sample_index_stack[-1] * max(self.num_workers, 1)
            + self.rank_worker_id()
        )

    @property
    def active_worker_batch_index(self) -> int:
        """Returns the current batch index for the actively iterating worker."""
        # Internal batch index is for the local worker. If using multiple workers per rank, this
        # must be multiplied by the number of workers and offset by the local worker index.
        return (
            WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1)
            + self.rank_worker_id()
        )

    def global_rank(self) -> int:
        """Returns the rank of this worker config as a global rank, not as a rank within the
        data parallel group."""
        if self.data_parallel_group is None:
            return self.rank
        return torch.distributed.get_global_rank(self.data_parallel_group, self.rank)

    def __eq__(self, other):
        """Do not compare everything to check for equal config"""
        if not isinstance(other, WorkerConfig):
            return NotImplemented
        return all(
            [
                self.rank == other.rank,
                self.world_size == other.world_size,
                self.num_workers == other.num_workers,
            ]
        )

    @staticmethod
    def default_worker_config(
        num_workers: int = 4,
        data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> "WorkerConfig":
        """Returns the default worker config using torch distributed if available.
        If torch distributed is not available, a single local rank is assumed."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank(data_parallel_group)
            world_size = torch.distributed.get_world_size(data_parallel_group)
        else:
            rank = 0
            world_size = 1
        return WorkerConfig(
            rank=rank,
            world_size=world_size,
            num_workers=num_workers,
            data_parallel_group=data_parallel_group,
        )

    def rank_worker_id(self) -> int:
        """Returns the worker id within the current rank."""
        if self._worker_override_global_rank:
            assert self.worker_id_offset == 0
            return self._worker_override_global_rank % self.num_workers
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return self.worker_id_offset
        assert worker_info.num_workers == self.num_workers
        # Apply the worker_id_offset as a left rotation of the logical worker ids.
        # This ensures that after restoring a checkpoint the first physical
        # worker (id=0) corresponds to the logical worker that should emit the
        # next sample. For example, if `worker_id_offset` is 1, logical worker
        # 1 becomes the first to emit a sample, shifting the ordering forward.
        return (worker_info.id + self.worker_id_offset) % worker_info.num_workers

    def assert_worker(self):
        """Checks if the current process is a worker (if configured so), and that the workers are
        properly configured."""
        if self.num_workers <= 1:
            assert self.rank_worker_id() == 0
        else:
            worker_info = torch.utils.data.get_worker_info()
            assert worker_info is not None, "Cannot iterate out of worker context"
            assert worker_info.num_workers == self.num_workers, (
                f"Actual number of workers for this rank ({worker_info.num_workers}) does not "
                f"match the configured number of workers ({self.num_workers})"
            )

    def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> int:
        """Returns the global worker index by multiplying the rank with the number of workers.
        Alternatively, you can override the local worker id.

        Args:
            override_local_worker_id (int, optional): The local worker id to override. None means
                the current worker, which is the default.
        """
        if self._worker_override_global_rank is not None:
            assert override_local_worker_id is None
            return self._worker_override_global_rank
        if override_local_worker_id is not None:
            return self.rank * self.num_workers + override_local_worker_id
        else:
            self.assert_worker()
            return self.rank * self.num_workers + self.rank_worker_id()

    def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int:
        """Returns the seed for the current worker (or a specified worker).
        Based on the current worker id and the seed offset, compute a seed.
        Alternatively, you can override the local worker id with a fixed one to
        pregenerate seeds for multiple workers.

        Args:
            override_local_worker_id (int, optional): The local worker id to override. None means
                the current worker, which is the default.
        """
        if self.num_workers == 0:
            # If we are not using workers, different ranks should still get a different seed
            global_worker_id = self.rank
        else:
            global_worker_id = self.global_worker_id(override_local_worker_id)
        seed_offset = self.seed_offset
        seed_hash = hashlib.sha1(f"{global_worker_id},{seed_offset}".encode("utf-8")).digest()
        return int.from_bytes(seed_hash, byteorder="big", signed=False) & 0xFFFFFFFF

    def config(self) -> Dict[str, Any]:
        return {
            "rank": self.rank,
            "world_size": self.world_size,
            "num_workers": self.num_workers,
            "data_parallel_group": (
                self.data_parallel_group.size() if self.data_parallel_group else None
            ),
        }

    def should_log(self, level: int) -> bool:
        return level <= self.worker_log_level

    def worker_log(self, data: dict) -> None:
        """Logs the given data to the worker debug file."""
        if self.worker_debug_path is None:
            print(json.dumps(data) + "\n", end="", flush=True)
        else:
            in_worker = torch.utils.data.get_worker_info() is not None
            # Additional "worker" with rank_worker_id=0 is the main process. All workers have +1
            # as their worker_id.
            worker_id = (
                self.rank * (self.num_workers + 1)
                + self.rank_worker_id()
                + (1 if in_worker else 0)
            )
            if self._worker_debug_file is None or self._worker_debug_file_worker_id != worker_id:
                if self._worker_debug_file is not None:
                    self._worker_debug_file.close()
                path = Path(
                    self.worker_debug_path.format(
                        worker_id=worker_id,
                        pid=multiprocessing.current_process().ident,
                    )
                )
                path.parent.mkdir(exist_ok=True, parents=True)
                self._worker_debug_file = path.open("w")
                self._worker_debug_file_worker_id = worker_id
            self._worker_debug_file.write(json.dumps(data) + "\n")
            self._worker_debug_file.flush()
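A brief usage sketch for the WorkerConfig above. The values are illustrative; default_worker_config() falls back to a single rank when torch.distributed is not initialized:

# Sketch only: constructs a WorkerConfig as defined above and pre-generates per-worker seeds.
from megatron.energon.worker import WorkerConfig

# Either derive the config from torch.distributed (single rank if not initialized) ...
worker_config = WorkerConfig.default_worker_config(num_workers=2)

# ... or spell it out explicitly (illustrative values).
explicit = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=42)

# worker_seed() can be pre-generated outside a worker by overriding the local worker id.
seeds = [explicit.worker_seed(override_local_worker_id=i) for i in range(explicit.num_workers)]
print(explicit.config(), seeds)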
Megatron-Energon/src/megatron/energon/wrappers/__init__.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.blend_dataset import BlendDataset
from megatron.energon.wrappers.concat_dataset import ConcatDataset
from megatron.energon.wrappers.epochize_dataset import EpochizeDataset
from megatron.energon.wrappers.filter_dataset import FilterDataset
from megatron.energon.wrappers.gc_dataset import GcDataset
from megatron.energon.wrappers.group_batch_dataset import GroupBatchDataset
from megatron.energon.wrappers.iter_map_dataset import IterMapDataset
from megatron.energon.wrappers.limit_dataset import LimitDataset
from megatron.energon.wrappers.log_sample_dataset import LogSampleDataset
from megatron.energon.wrappers.map_dataset import MapDataset
from megatron.energon.wrappers.mix_batch_dataset import (
    MixBatchDataset,
    concat_pad,
    generic_concat,
    homogeneous_concat_mix,
)
from megatron.energon.wrappers.packing_dataset import PackingDataset
from megatron.energon.wrappers.repeat_dataset import RepeatDataset
from megatron.energon.wrappers.shuffle_buffer_dataset import ShuffleBufferDataset
from megatron.energon.wrappers.skip import SkipSample

__all__ = [
    "BatchDataset",
    "BlendDataset",
    "ConcatDataset",
    "EpochizeDataset",
    "FilterDataset",
    "GcDataset",
    "GroupBatchDataset",
    "IterMapDataset",
    "LimitDataset",
    "LogSampleDataset",
    "MapDataset",
    "MixBatchDataset",
    "RepeatDataset",
    "ShuffleBufferDataset",
    "SkipSample",
    "PackingDataset",
    "concat_pad",
    "generic_concat",
    "homogeneous_concat_mix",
]
Megatron-Energon/src/megatron/energon/wrappers/_log_exception.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import traceback
from typing import Optional, Sequence

from megatron.energon.source_info import SourceInfo


def log_exception(_e: Exception, sample, sources: Optional[Sequence[SourceInfo]] = None):
    traceback.print_exc()
    print("-" * 10)
    if sources:
        print("Sources:")
        for source in sources:
            print(
                f" - {source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}"
            )
        print("-" * 10)
    sample_str = str(sample)
    if len(sample_str) > 400:
        sample_str = sample_str[:200] + "..." + sample_str[-200:]
    print(sample_str)
    print("-" * 10)
Megatron-Energon/src/megatron/energon/wrappers/base.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union

from megatron.energon.flavors.base_dataset import (
    FlexState,
    Sample,
    SavableDataset,
    add_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.worker import WorkerConfig

T = TypeVar("T")
T_sample = TypeVar("T_sample", covariant=True)
T_sample_out = TypeVar("T_sample_out", covariant=True)
T_sample_in = TypeVar("T_sample_in", covariant=True)


class BaseWrapperDataset(
    SavableDataset[T_sample_out],
    Generic[T_sample_in, T_sample_out],
    ABC,
):
    """Base class for dataset wrappers. All dataset wrappers should derive from this. A dataset
    wrapper takes one dataset and modifies its samples to make a new dataset. This can be for
    shuffling samples or applying custom functions to the data. Some wrappers only modify the
    length of the dataset or how it's repeated."""

    datasets: Tuple[SavableDataset[T_sample_in], ...]

    def __init__(
        self,
        datasets: Union[SavableDataset[T_sample_in], Iterable[SavableDataset[T_sample_in]]],
        *,
        worker_config: WorkerConfig,
    ):
        super().__init__(worker_config=worker_config)
        if isinstance(datasets, SavableDataset):
            self.datasets = (datasets,)
        else:
            self.datasets = tuple(datasets)
        for d in self.datasets:
            # Check that the dataset worker configs are the same as the wrapper worker config
            assert d.worker_config == self.worker_config, (
                "Dataset and wrapper worker configs must match."
            )

    @property
    def dataset(self) -> SavableDataset:
        """Convenience property, if only one dataset is wrapped."""
        assert len(self.datasets) == 1
        return self.datasets[0]

    def can_restore_sample(self) -> bool:
        return all(ds.can_restore_sample() for ds in self.datasets)

    def assert_can_restore(self) -> None:
        for ds in self.datasets:
            ds.assert_can_restore()

    def worker_has_samples(self) -> bool:
        return any(ds.worker_has_samples() for ds in self.datasets)

    def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDataset]:
        """Find the outermost dataset wrapped in this dataset that is of type cls."""
        for ds in self.datasets:
            if isinstance(ds, cls):
                return ds
            elif isinstance(ds, BaseWrapperDataset):
                res = ds._find_wrapped_dataset(cls)
                if res is not None:
                    return res
        return None

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
        if len(self.datasets) == 1:
            return self.datasets[0].restore_sample(restore_key)
        else:
            id, ds_idx = restore_key[:2]
            assert id == type(self).__name__
            restore_key = restore_key[2:]
            assert isinstance(ds_idx, int)
            return add_sample_restore_key(
                self.datasets[ds_idx].restore_sample(restore_key),
                ds_idx,
                src=self,
            )

    def save_state(self) -> FlexState:
        own_state = super().save_state()
        return FlexState(datasets=[ds.save_state() for ds in self.datasets], **own_state)

    def restore_state(self, state: FlexState) -> None:
        assert len(self.datasets) == len(state["datasets"])
        for dataset, dstate in zip(self.datasets, state["datasets"]):
            dataset.restore_state(dstate)
        super().restore_state(state)

    def reset_state_deep(self) -> None:
        """Resets the state of the inner datasets and then the own state."""
        for ds in self.datasets:
            if isinstance(ds, BaseWrapperDataset):
                ds.reset_state_deep()
            else:
                ds.reset_state_own()
        self.reset_state_own()

    @abstractmethod
    def reset_state_own(self) -> None:
        """Resets the state of the dataset, excl. the inner datasets."""
        ...


class SampleIndex(Savable):
    """A simple class to hold the sample index for one worker."""

    worker_config: WorkerConfig
    current_idx: int

    actives = 0

    def __init__(self, worker_config: WorkerConfig, *, src: Any) -> None:
        self.worker_config = worker_config
        self.current_idx = 0
        self.src = src

    def get_next(self) -> int:
        res = self.current_idx
        self.current_idx += 1
        return res

    @contextmanager
    def ctx(self, sample_idx: Optional[int] = None):
        if sample_idx is None:
            sample_idx = self.get_next()
        assert WorkerConfig.active_worker_config is not None
        WorkerConfig.active_worker_config.worker_push_sample_index(sample_idx)
        # print(" " * SampleIndex.actives + f"Activated from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}")
        SampleIndex.actives += 1
        try:
            yield sample_idx
        finally:
            assert WorkerConfig.active_worker_config is not None
            popped = WorkerConfig.active_worker_config.worker_pop_sample_index()
            SampleIndex.actives -= 1
            # print(" " * SampleIndex.actives + f"Deactivate from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}")
            assert popped == sample_idx, f"Expected {sample_idx}, got {popped}"

    def iter_ctx(
        self,
        it: Iterable[T_sample],
        sample_idx: Optional[int] = None,
    ) -> Generator[Tuple[int, T_sample], None, None]:
        it = iter(it)
        try:
            while True:
                try:
                    with self.ctx(sample_idx) as res_sample_idx:
                        x = next(it)
                    yield res_sample_idx, x
                except StopIteration:
                    break
        finally:
            if hasattr(it, "close"):
                it.close()

    def save_state(self) -> int:
        return self.current_idx

    def restore_state(self, state: Optional[int]) -> None:
        if state is None:
            self.current_idx = 0
        else:
            self.current_idx = state


def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]:
    """Gets the restore key from an arbitrary sample."""
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        return sample.__restore_key__
    elif isinstance(sample, dict) and "__restore_key__" in sample:
        return sample["__restore_key__"]
    else:
        return None
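A hand-driven sketch of how SampleIndex cooperates with the sample-index stack on WorkerConfig. In normal use this bookkeeping is driven by the wrapper datasets and the surrounding loader machinery, not by user code:

# Sketch only: manually activate a worker config and use SampleIndex.ctx() once.
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import SampleIndex

worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
sample_index = SampleIndex(worker_config, src="demo")

worker_config.worker_activate(sample_index=0)
try:
    with sample_index.ctx() as idx:
        # Inside the context, the index has been pushed onto the active worker's stack.
        print("sample index", idx, "->", worker_config.active_worker_sample_index)
finally:
    worker_config.worker_deactivate()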
Megatron-Energon/src/megatron/energon/wrappers/batch_dataset.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.skip import SkipSample

T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)


class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]):
    """This dataset wrapper transforms a dataset of samples into a dataset of batches."""

    batch_size: int
    batcher: Callable[[List[T_batch_sample]], T_batch]
    drop_last: bool
    error_handler: Callable[[Exception, list[T_batch_sample], Sequence[SourceInfo]], None]
    _sample_index: SampleIndex
    _generator_sample_keys: Optional[Any]
    _generator_offset: Optional[int]
    _last_batch_failures: int = 0

    _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset")

    def __init__(
        self,
        dataset: SavableDataset[T_batch_sample],
        batch_size: int,
        batcher: Callable[[List[T_batch_sample]], T_batch],
        *,
        batcher_stateless: bool = False,
        batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        drop_last: bool = False,
        error_handler: Callable[
            [Exception, list[T_batch_sample], Sequence[SourceInfo]], None
        ] = log_exception,
        failure_tolerance: int = 100,
        worker_config: WorkerConfig,
    ):
        """Construct a BatchDataset.

        Args:
            dataset: The input dataset to wrap
            batch_size: The desired batch size. The last batch may be smaller.
            batcher: Function which combines separate samples into a single object. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample.
            batcher_stateless: If True, the batcher is stateless, thus samples can be
                stored/restored.
            batcher_config: Configuration for the batcher function. If callable, it should return
                the configuration. Defaults to None.
            drop_last: If True, the last batch is dropped if it is smaller than the batch size.
            error_handler: Function which handles exceptions raised by the batcher. The default
                implementation logs the exception.
            failure_tolerance: The number of consecutive failures after which the dataset is
                considered broken. Set to 0 to disable.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.batch_size = batch_size
        self.batcher = batcher
        self.batcher_stateless = batcher_stateless
        self.batcher_config = batcher_config
        self.drop_last = drop_last
        self.error_handler = error_handler
        self.failure_tolerance = failure_tolerance
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._sample_index = SampleIndex(self.worker_config, src=self)
        self._generator_sample_keys = None
        self._generator_offset = None

    def len_worker(self, worker_idx: int | None = None) -> int:
        n_samples = self.dataset.len_worker(worker_idx)
        n_batches = n_samples // self.batch_size
        if n_samples % self.batch_size != 0 and not self.drop_last:
            n_batches += 1
        return n_batches

    def __iter__(self) -> Iterator[T_batch]:
        batch: List[T_batch_sample] = []
        sample_restore_keys = []

        if self._generator_sample_keys is not None:
            sample_restore_keys = self._generator_sample_keys
            assert self._generator_offset is not None
            batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
            with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
                batch_sample = self.batcher(batch)
            assert isinstance(batch_sample, Generator)
            assert inspect.isgeneratorfunction(self.batcher), (
                f"Generator in {self.batcher} but not marked as such."
            )
            target_offset = self._generator_offset
            self._generator_offset = 0
            for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                self._sample_index.iter_ctx(batch_sample, sample_idx)
            ):
                # Skip other samples
                if batch_sub_idx >= target_offset:
                    self._generator_offset = batch_sub_idx + 1
                    yield set_sample_restore_key(
                        inner_batch_sample,
                        sample_idx,
                        batch_sub_idx,
                        *sample_restore_keys,
                        src=self,
                    )
            self._generator_sample_keys = None
            self._generator_offset = None
            batch.clear()
            sample_restore_keys = []

        def flush() -> Generator[T_batch, None, None]:
            try:
                with self._sample_index.ctx() as sample_idx:
                    batch_sample = self.batcher(batch)
                if isinstance(batch_sample, Generator):
                    assert inspect.isgeneratorfunction(self.batcher), (
                        f"Generator in {self.batcher} but not marked as such."
                    )
                    self._generator_sample_keys = sample_restore_keys
                    self._generator_offset = 0
                    for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                        self._sample_index.iter_ctx(batch_sample, sample_idx)
                    ):
                        self._last_batch_failures = 0
                        self._generator_offset = batch_sub_idx + 1
                        yield set_sample_restore_key(
                            inner_batch_sample,
                            sample_idx,
                            batch_sub_idx,
                            *sample_restore_keys,
                            src=self,
                        )
                    self._generator_sample_keys = None
                    self._generator_offset = None
                else:
                    self._last_batch_failures = 0
                    set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
                    yield batch_sample
            except GeneratorExit:
                raise
            except SkipSample:
                pass
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(batch)
            except Exception as e:
                self.error_handler(e, batch)
                self._last_batch_failures += 1
                if (
                    self.failure_tolerance > 0
                    and self._last_batch_failures >= self.failure_tolerance
                ):
                    raise FatalSampleError.from_sample(
                        batch,
                        f"BatchDataset {self.batcher} failed {self._last_batch_failures} times "
                        "in a row. Likely your code or dataset are broken.",
                    )
            finally:
                sample_restore_keys.clear()

        for sample in self.dataset:
            batch.append(sample)
            sample_restore_keys.append(get_sample_restore_key(sample))
            if len(batch) == self.batch_size:
                yield from flush()
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield from flush()

    def can_restore_sample(self) -> bool:
        # Cannot really verify if the returned elements contain a __restore_key__.
        # If the user wants to use this, well...
        return super().can_restore_sample() and self.batcher_stateless

    def assert_can_restore(self) -> None:
        assert self.batcher_stateless, (
            f"Batcher {self.batcher} must be stateless to restore samples"
        )
        super().assert_can_restore()

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_batch:
        # We need to store multiple indices to restore a batch.
        self.assert_can_restore()
        if inspect.isgeneratorfunction(self.batcher):
            id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key
            assert id == type(self).__name__
        else:
            id, sample_idx, *samples_restore_keys = restore_key
            assert id == type(self).__name__
        batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys]
        try:
            with self._sample_index.ctx(sample_idx):
                batch_sample = self.batcher(batch)
            if isinstance(batch_sample, Generator):
                assert inspect.isgeneratorfunction(self.batcher), (
                    f"Generator in {self.batcher} but not marked as such."
                )
                for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                    self._sample_index.iter_ctx(batch_sample, sample_idx)
                ):
                    self._last_batch_failures = 0
                    if cur_batch_sub_idx == batch_sub_idx:
                        return set_sample_restore_key(
                            inner_batch_sample,
                            sample_idx,
                            batch_sub_idx,
                            *samples_restore_keys,
                            src=self,
                        )
                assert False, f"Batch sub-index {batch_sub_idx} not found in batch"
            else:
                self._last_batch_failures = 0
                return set_sample_restore_key(
                    batch_sample,
                    sample_idx,
                    *samples_restore_keys,
                    src=self,
                )
        except GeneratorExit:
            raise FatalSampleError.from_sample(
                batch,
                f"BatchDataset {self.batcher} generator exited while trying to restore a batch.",
            )
        except SkipSample:
            raise FatalSampleError.from_sample(
                batch,
                f"BatchDataset {self.batcher} skipped while trying to restore a batch.",
            )
        except SYSTEM_EXCEPTIONS:
            raise FatalSampleError.from_sample(batch)
        except Exception as e:
            self.error_handler(e, batch)
            self._last_batch_failures += 1
            if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance:
                raise FatalSampleError.from_sample(
                    batch,
                    f"BatchDataset {self.batcher} failed {self._last_batch_failures} times "
                    "in a row. Likely your code or dataset are broken.",
                )

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "batch_size": self.batch_size,
            "batcher": self._function_config(self.batcher),
            **(
                {
                    "batcher_config": (
                        self.batcher_config()
                        if callable(self.batcher_config)
                        else self.batcher_config
                    )
                }
                if self.batcher_config
                else {}
            ),
            "batcher_stateless": self.batcher_stateless,
            "drop_last": self.drop_last,
            "error_handler": self._function_config(self.error_handler),
            "worker_config": self.worker_config.config(),
            "dataset": self.dataset.config(),
        }

    def __str__(self):
        return (
            f"BatchDataset(batch_size={self.batch_size}, drop_last={self.drop_last}, "
            f"batcher={self.batcher}, dataset={self.dataset})"
        )
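A minimal sketch of a stateless batcher of the shape BatchDataset expects. Stacking tensors with torch is an assumption for illustration; any callable from a list of samples to a batch object works, and the inner dataset in the commented wiring is hypothetical:

# Sketch only: a stateless batcher plus the (commented) wiring into BatchDataset.
from typing import List

import torch


def stack_batcher(samples: List[torch.Tensor]) -> torch.Tensor:
    # Combine the collected samples into one batch object (here: a stacked tensor).
    return torch.stack(samples)


# The batcher itself can be exercised directly:
print(stack_batcher([torch.zeros(2), torch.ones(2)]).shape)  # torch.Size([2, 2])

# Wiring it up (inner_dataset is a hypothetical SavableDataset of tensors):
# batched = BatchDataset(inner_dataset, batch_size=8, batcher=stack_batcher,
#                        batcher_stateless=True, drop_last=True, worker_config=worker_config)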
Megatron-Energon/src/megatron/energon/wrappers/blend_dataset.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar

import torch

from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class BlendDataset(BaseWrapperDataset[T_sample, T_sample]):
    """
    This dataset wrapper blends multiple iterable datasets together given a weighting.
    The datasets may be infinite. This dataset is always infinite.
    """

    datasets: List[SavableDataset[T_sample]]
    weights: Tuple[float, ...]
    dataset_weights: Sequence[Tuple[SavableDataset[T_sample], float]]
    exhausted: List[bool]
    _worker_rng: WorkerRng

    _savable_fields = ("exhausted", "_worker_rng")

    def __init__(
        self,
        *dataset_weights: Tuple[SavableDataset[T_sample], float],
        worker_config: WorkerConfig,
    ):
        """Construct a BlendDataset.

        Args:
            dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight
                between 0 and 1. The output samples are sampled from the input datasets with the
                given probabilities.
            worker_config: Configuration for the workers.
        """
        # datasets = [dataset for dataset, _weight in dataset_weights]
        self.datasets, self.weights = zip(*dataset_weights)
        super().__init__(self.datasets, worker_config=worker_config)
        self.dataset_weights = dataset_weights
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._worker_rng = WorkerRng(self.worker_config)
        self.exhausted = [False] * len(self.weights)

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Give the number of samples in inner datasets, disregarding the weight
        return sum(dataset.len_worker(worker_idx) for dataset in self.datasets)

    def __iter__(self) -> Iterator[T_sample]:
        assert self.worker_has_samples(), "Cannot blend all empty datasets"

        # Create a list of datasets and their weights, but
        # set the weight to 0 if the dataset has no samples on this worker.
        dataset_iters = []
        weights = []
        for idx, (dataset, weight) in enumerate(self.dataset_weights):
            assert weight > 0, "All blending weights must be > 0"
            if dataset.worker_has_samples():
                dataset_iters.append(iter(dataset))
                weights.append(weight)
            else:
                dataset_iters.append(None)
                weights.append(0)
                self.exhausted[idx] = True

        weights = torch.tensor(weights, dtype=torch.float32)
        if weights.sum() == 0:
            raise RuntimeError(
                "There is a worker with no samples in any of the blended datasets. "
                "This can happen if you have a lot of workers and your dataset is too small. "
                "Currently this case is not supported."
            )

        # Some may already be exhausted on this worker when restoring a state.
        for idx, exhausted in enumerate(self.exhausted):
            if exhausted:
                weights[idx] = 0
                dataset_iters[idx] = None

        while True:
            ds_idx = self._worker_rng.choice_idx(probs=weights)
            if dataset_iters[ds_idx] is None:
                if all(dataset_iter is None for dataset_iter in dataset_iters):
                    break
                continue
            try:
                sample = next(dataset_iters[ds_idx])
            except StopIteration:
                dataset_iters[ds_idx] = None
                weights[ds_idx] = 0
                self.exhausted[ds_idx] = True
                if all(dataset_iter is None for dataset_iter in dataset_iters):
                    break
            else:
                yield add_sample_restore_key(sample, ds_idx, src=self)

        self.exhausted = [False] * len(self.dataset_weights)

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset_weights": [
                (dataset.config(), weight) for dataset, weight in self.dataset_weights
            ],
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"BlendDataset(dataset_weights={self.dataset_weights})"
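A short sketch of the blending behaviour. The weighted picking below uses torch.multinomial only to illustrate sampling proportional to the weights; treating it as what WorkerRng.choice_idx(probs=...) does internally is an assumption, and ds_a/ds_b in the commented wiring are hypothetical datasets:

# Sketch only: weighted index picking, plus the (commented) BlendDataset wiring.
import torch

weights = torch.tensor([0.7, 0.3], dtype=torch.float32)
counts = torch.zeros(2)
for _ in range(1000):
    counts[torch.multinomial(weights, 1).item()] += 1
print(counts / counts.sum())  # roughly [0.7, 0.3]

# Construction with two hypothetical datasets sharing one worker_config:
# blended = BlendDataset((ds_a, 0.7), (ds_b, 0.3), worker_config=worker_config)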
Megatron-Energon/src/megatron/energon/wrappers/buffer.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import (
    Any,
    Dict,
    Generator,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key

T_sample = TypeVar("T_sample")


class SavableSampleBuffer(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """A buffer of samples, savable."""

    _buffer: List[T_sample]
    _restore_keys: List[Tuple[Union[str, int, tuple], ...]]

    _savable_fields = ("_restore_keys",)

    _restore_pending: bool = False

    def __init__(self, dataset: SavableDataset[T_sample], *, worker_config: WorkerConfig):
        super().__init__(dataset, worker_config=worker_config)
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._buffer = []
        self._restore_keys = []

    def worker_start(self) -> None:
        if self._restore_pending:
            assert len(self._buffer) == 0
            self._restore_pending = False
            for restore_key in self._restore_keys:
                self._buffer.append(self.restore_sample(restore_key))
        assert len(self._buffer) == len(self._restore_keys)

    def append(self, sample: T_sample) -> T_sample:
        self._buffer.append(sample)
        self._restore_keys.append(get_sample_restore_key(sample))
        return sample

    def extend(self, samples: List[T_sample], restore_keys: Optional[Sequence[Any]] = None) -> None:
        self._buffer.extend(samples)
        if restore_keys is None:
            self._restore_keys.extend(get_sample_restore_key(sample) for sample in samples)
        else:
            self._restore_keys.extend(restore_keys)

    def append_iter(self) -> Generator[T_sample, None, None]:
        for sample in self.dataset:
            yield self.append(sample)

    def pop(self, index: int) -> T_sample:
        self._restore_keys.pop(index)
        return self._buffer.pop(index)

    def flush(self) -> Tuple[List[T_sample], Tuple[Any, ...]]:
        buffer = list(self._buffer)
        restore_key = tuple(self._restore_keys)
        self._buffer.clear()
        self._restore_keys.clear()
        return buffer, restore_key

    @property
    def buffer(self) -> List[T_sample]:
        return self._buffer

    def __iter__(self) -> Iterator[T_sample]:
        return iter(self._buffer)

    def __getitem__(self, index: Union[int, slice]) -> Union[T_sample, List[T_sample]]:
        return self._buffer[index]

    def __setitem__(self, index: Union[int, slice], value: T_sample) -> None:
        self._buffer[index] = value
        if isinstance(index, slice):
            self._restore_keys[index] = (get_sample_restore_key(v) for v in value)
        else:
            self._restore_keys[index] = get_sample_restore_key(value)

    def __delitem__(self, index: Union[int, slice]) -> None:
        del self._buffer[index]
        del self._restore_keys[index]

    def len_worker(self, worker_idx: int | None = None) -> int:
        self.worker_config.assert_worker()
        assert worker_idx is None or worker_idx == self.worker_config.rank_worker_id(), (
            "SavableSampleBuffer.len_worker only available for the current worker"
        )
        return len(self._restore_keys)

    def len_rank(self) -> int:
        raise NotImplementedError("len_rank is not available for SavableSampleBuffer")

    def save_state(self) -> FlexState:
        # Don't call super().save_state() because we don't want to save the wrapped datasets
        # Just save the own state
        return SavableDataset.save_state(self)

    def restore_state(self, state: FlexState) -> None:
        # Don't call super().restore_state() because we don't want to restore the wrapped datasets
        # Just restore the own state
        SavableDataset.restore_state(self, state)
        self._restore_pending = True

    def restore_key(self) -> Tuple[Union[str, int], ...]:
        return tuple(self._restore_keys)

    def restore_samples(
        self, index: Tuple[Union[str, int, tuple], ...]
    ) -> Tuple[Tuple[Union[str, int, tuple], ...], List[T_sample]]:
        buffer = []
        restore_keys = []
        for sub_index in index:
            sample = self.restore_sample(sub_index)
            restore_keys.append(get_sample_restore_key(sample))
            buffer.append(sample)
        return tuple(restore_keys), buffer

    def clear(self) -> None:
        self._buffer.clear()
        self._restore_keys.clear()

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "worker_config": self.worker_config.config(),
        }

    def debug_print(self, indent: str = ""):
        print(
            f"{indent}SavableSampleBuffer(size={len(self._restore_keys)}, "
            f"res_pend={self._restore_pending}):\n",
            end="",
        )
        for i, (sample, restore_key) in enumerate(zip(self._buffer, self._restore_keys)):
            print(f"{indent}  Sample {i} [{restore_key!r}]: {sample.__key__}\n", end="")

    def __str__(self):
        return f"SavableSampleBuffer(size={len(self._buffer)})"
Megatron-Energon/src/megatron/energon/wrappers/concat_dataset.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class ConcatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """
    This dataset wrapper concatenates multiple iterable datasets together. The datasets must be
    finite, otherwise not all datasets can be sampled. This is only useful for validation / test
    datasets.
    """

    def __init__(
        self,
        *datasets: SavableDataset[T_sample],
        worker_config: WorkerConfig,
    ):
        """Construct a concatenated dataset."""
        super().__init__(datasets, worker_config=worker_config)
        assert len(self) >= 0, "Datasets must be finite."

    def reset_state_own(self) -> None:
        return

    def len_worker(self, worker_idx: int | None = None) -> int:
        return sum(dataset.len_worker(worker_idx) for dataset in self.datasets)

    def __iter__(self) -> Iterator[T_sample]:
        for ds_idx, dataset in enumerate(self.datasets):
            for sample in dataset:
                yield add_sample_restore_key(
                    sample,
                    ds_idx,
                    src=self,
                )

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "datasets": [dataset.config() for dataset in self.datasets],
        }

    def __str__(self):
        return f"ConcatDataset(datasets={self.datasets})"
Megatron-Energon/src/megatron/energon/wrappers/epochize_dataset.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, Optional, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class EpochizeDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """
    Uses the base dataset, and creates one epoch, which has `length` samples. Keeps the underlying
    dataset iterator alive over epochs (i.e. if it is an infinite dataset, it will keep the state).
    Repeats the underlying dataset if the iterator is exhausted.
    """

    length: int
    _active_iter: Optional[Iterator[T_sample]]
    _offset: int

    _savable_fields = ("_offset",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        length: int,
        worker_config: WorkerConfig,
    ):
        """
        Create the epochized dataset.

        Args:
            dataset: The source dataset (possibly infinite)
            length: Number of samples to iterate before iteration stops (i.e. one epoch). When
                iteration continues, the original dataset iterator is resumed and only restarts
                if exhausted.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.length = length
        self._active_iter = None
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._offset = 0

    def __iter__(self) -> Iterator[T_sample]:
        # Compute the local length for this worker, i.e. all workers' lengths sum up to the total
        if self.worker_config.num_workers <= 1:
            local_length = self.length
        else:
            local_length = self.length // self.worker_config.num_workers
            if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers:
                local_length += 1
        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "EpochizeDataset.epoch_start",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "offset": self._offset,
                    "local_length": local_length,
                    "length": self.length,
                }
            )

        offset_range = list(range(self._offset, local_length))
        # Only iterate if there are samples to iterate
        if len(offset_range) > 0:
            if self._active_iter is None:
                self._active_iter = iter(self.dataset)

            for idx in offset_range:
                self._offset = (idx + 1) % local_length
                try:
                    sample = next(self._active_iter)
                except StopIteration:
                    break
                yield sample

        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "EpochizeDataset.epoch_end",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "offset": self._offset,
                    "local_length": local_length,
                    "length": self.length,
                }
            )

    def len_worker(self, worker_idx: int | None = None) -> int:
        if worker_idx is None:
            self.worker_config.assert_worker()
            worker_idx = self.worker_config.rank_worker_id()
        if self.worker_config.num_workers <= 1:
            assert worker_idx == 0
            return self.length
        else:
            local_length = self.length // self.worker_config.num_workers
            if worker_idx < self.length % self.worker_config.num_workers:
                local_length += 1
            return local_length

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "length": self.length,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"EpochizeDataset(length={self.length}, dataset={self.dataset})"
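The per-worker epoch length computed in EpochizeDataset.__iter__ and len_worker can be restated as a small standalone function; a minimal sketch of that arithmetic:

# Sketch only: the epoch-length split used above, as a free function for illustration.
def local_epoch_length(length: int, num_workers: int, worker_id: int) -> int:
    if num_workers <= 1:
        return length
    local_length = length // num_workers
    # The first (length % num_workers) workers each take one extra sample.
    if worker_id < length % num_workers:
        local_length += 1
    return local_length


# An epoch of 10 samples split over 4 workers: the first two workers get one extra sample.
print([local_epoch_length(10, 4, w) for w in range(4)])  # [3, 3, 2, 2]
assert sum(local_epoch_length(10, 4, w) for w in range(4)) == 10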
Megatron-Energon/src/megatron/energon/wrappers/filter_dataset.py  0 → 100644

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar, Union

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex

T_sample = TypeVar("T_sample")


class FilterDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """This dataset wrapper applies a custom filter function to each sample and does not yield
    filtered samples."""

    filter_fn: Callable[[T_sample], bool]
    filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    _sample_index: SampleIndex

    _savable_fields = ("_sample_index",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        filter_fn: Callable[[T_sample], bool],
        filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        worker_config: WorkerConfig,
    ):
        """Construct a FilterDataset.

        Args:
            dataset: The input dataset to wrap
            filter_fn: The function to apply to each sample. If it returns `True`, the sample is
                accepted.
            filter_fn_config: Configuration for the filter function. If callable, it should return
                the configuration. Defaults to None.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.filter_fn = filter_fn
        self.filter_fn_config = filter_fn_config
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._sample_index = SampleIndex(self.worker_config, src=self)

    def len_worker(self, worker_idx: int | None = None) -> int:
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_sample]:
        for sample in self.dataset:
            with self._sample_index.ctx():
                filter_res = self.filter_fn(sample)
            if filter_res:
                yield sample

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "filter_fn": self._function_config(self.filter_fn),
            **(
                {
                    "filter_fn_config": (
                        self.filter_fn_config()
                        if callable(self.filter_fn_config)
                        else self.filter_fn_config
                    )
                }
                if self.filter_fn_config
                else {}
            ),
        }

    def __str__(self):
        return f"FilterDataset(filter_fn={self.filter_fn}, dataset={self.dataset})"
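A minimal sketch of a filter function of the shape FilterDataset expects (return True to keep a sample). The caption attribute used here is hypothetical, and the commented wiring assumes an existing inner dataset and worker_config:

# Sketch only: a filter_fn for FilterDataset.
def keep_long_captions(sample) -> bool:
    # Keep only samples whose (assumed) caption field has at least 10 characters.
    return len(getattr(sample, "caption", "")) >= 10


# Wiring (inner_dataset and worker_config are assumed to exist):
# filtered = FilterDataset(inner_dataset, filter_fn=keep_long_captions, worker_config=worker_config)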
Megatron-Energon/src/megatron/energon/wrappers/gc_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import gc
from typing import Any, Dict, Generic, Iterator, TypeVar

import torch
import torch.utils.data
import torch.utils.data.dataloader
from torch.distributed._shard.sharded_tensor import ShardedTensorBase
from torch.distributed.distributed_c10d import reduce_op

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")

_frozen_cuda_tensors = set()
_frozen_cuda_tensors_initialized = False

GC_DEFAULT_EVERY_N_ITER = 10


class GcFreezeError(RuntimeError):
    pass


def gc_init_worker(worker_id: int):
    """This function should be called by any forked worker process that uses CUDA.
    It should be called as early as possible in the worker process, ideally in
    the worker_init_fn of the DataLoader.

    By keeping a reference to all CUDA tensors in the worker process, we can
    prevent the forked tensors from being garbage collected."""
    global _frozen_cuda_tensors_initialized, _frozen_cuda_tensors

    num_tensors = 0
    for o in gc.get_objects():
        try:
            if o is not reduce_op:
                if isinstance(o, torch.Tensor):
                    if isinstance(o, ShardedTensorBase) or o.is_cuda:
                        # Calling .is_cuda or any hasattr on ShardedTensor will raise an error
                        # Hence, o.is_cuda is only called if o is not a ShardedTensor (in the if above)
                        _frozen_cuda_tensors.add(o)
                        num_tensors += 1
                elif isinstance(o, torch.utils.data.dataloader._MultiProcessingDataLoaderIter):
                    o._shutdown = True
        except ReferenceError:
            # Can happen if the object is a weakref proxy, don't care
            pass

    _frozen_cuda_tensors_initialized = True


class GcDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Applies a garbage collection step. This is needed because Python garbage collection
    does not work well with very large objects, such as tensors. This case occurs when a few
    hundred objects (some of them large tensors) are created and released every epoch, many of
    which are alive at the same time but released later. Those objects may end up in gc
    generation 2, where they can live until enough new objects have been created to trigger the
    automatic gen2 collection. To avoid this memory leak, `gc.collect()` is best called
    regularly. In addition, if `gc.freeze()` is used before the loop, the objects currently
    alive are removed from garbage collection checks, making the gc faster.
    """

    every_n_iter: int
    freeze: bool

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        worker_config: WorkerConfig,
        every_n_iter: int = GC_DEFAULT_EVERY_N_ITER,
        freeze: bool = True,
    ):
        """Construct a GcDataset, which applies garbage collection every `every_n_iter` iterations.

        Args:
            dataset: The input dataset to wrap
            every_n_iter: How often to perform garbage collection
            freeze: If true, run `gc.freeze()` before the loop, and `gc.unfreeze()` after the loop.
                This will speed up garbage collection, but will keep all initially alive objects
                alive until the end of the loop (i.e. if the dataset state was restored, that state
                will be saved as well).
        """
        super().__init__(dataset, worker_config=worker_config)
        self.every_n_iter = every_n_iter
        self.freeze = freeze

    def reset_state_own(self) -> None:
        return

    def len_worker(self, worker_idx: int | None = None) -> int:
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_sample]:
        in_worker = torch.utils.data.get_worker_info() is not None
        if in_worker and not _frozen_cuda_tensors_initialized:
            raise GcFreezeError(
                "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process."
            )
        if self.freeze:
            gc.collect()
            gc.freeze()
        try:
            iter = 0
            for sample in self.dataset:
                yield sample
                iter += 1
                if iter >= self.every_n_iter:
                    gc.collect()
                    iter = 0
        finally:
            if self.freeze:
                gc.unfreeze()

    def config(self) -> Dict[str, Any]:
        # This is transparent, no config to be saved (it does not affect the dataset)
        return self.dataset.config()

    def __str__(self):
        return f"GcDataset(every_n_iter={self.every_n_iter}, dataset={self.dataset}, freeze={self.freeze})"
Megatron-Energon/src/megatron/energon/wrappers/group_batch_dataset.py 0 → 100644
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Hashable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from megatron.energon.edataclass import edataclass
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import (
    FlexState,
    SavableDataset,
    set_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex
from megatron.energon.wrappers.buffer import SavableSampleBuffer
from megatron.energon.wrappers.skip import SkipSample

T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)


@edataclass
class Bucket(Savable, Generic[T_batch_sample]):
    batch_size: int
    samples: SavableSampleBuffer[T_batch_sample]

    def save_state(self) -> FlexState:
        return FlexState(
            batch_size=self.batch_size,
            samples=self.samples.save_state(),
        )

    def restore_state(self, state: FlexState):
        self.batch_size = state["batch_size"]
        self.samples.restore_state(state["samples"])


class GroupBatchDataset(
    BaseWrapperDataset[T_batch_sample, T_batch],
    Generic[T_batch_sample, T_batch],
):
    """This dataset wrapper transforms a dataset of samples into a dataset of batches, grouped by some criterion.
    The reported length is not exact: the number of batches cannot be predicted, since there is no fixed batch size,
    so the inner dataset size is returned instead.

    An example use case: image-text samples which are to be grouped by the image size into three
    size categories (e.g. 128x128, 256x256, 512x512) for efficient augmentation and batching.
    """

    dataset: SavableDataset[T_batch_sample]
    sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]]
    batcher: Callable[[List[T_batch_sample]], T_batch]
    drop_last: bool
    error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None]
    _group_key_sample_index: SampleIndex
    _batch_sample_index: SampleIndex
    _buckets: Dict[Hashable, Bucket[T_batch_sample]]
    _last_batch_failures: int = 0

    def __init__(
        self,
        dataset: SavableDataset[T_batch_sample],
        fixed_batch_size: Optional[int],
        sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]],
        batcher: Callable[[List[T_batch_sample]], T_batch],
        *,
        batcher_stateless: bool = False,
        batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        drop_last: bool = False,
        error_handler: Callable[
            [Exception, List[T_batch_sample], Sequence[SourceInfo]], None
        ] = log_exception,
        failure_tolerance: int = 100,
        worker_config: WorkerConfig,
    ):
        """Construct a GroupBatchDataset.

        Args:
            dataset: The input dataset to wrap
            fixed_batch_size: Fixed batch size to use for all buckets. If None, the batch size is determined by the sample_group_key function.
            sample_group_key: Function which determines the bucket of a sample.
            batcher: Function which combines separate samples into a single object. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample.
            drop_last: If True, the last batch is dropped if it is smaller than the batch size.
            error_handler: Handler for errors. Defaults to logging and ignoring the exception.
            failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.fixed_batch_size = fixed_batch_size
        self.sample_group_key = sample_group_key
        self.batcher = batcher
        self.batcher_stateless = batcher_stateless
        self.batcher_config = batcher_config
        self.drop_last = drop_last
        self.error_handler = error_handler
        self.failure_tolerance = failure_tolerance
        self.reset_state_own()
        assert not inspect.isgeneratorfunction(batcher), (
            f"Batcher {batcher} must not be a generator function for grouped batching."
        )

    def reset_state_own(self) -> None:
        self._group_key_sample_index = SampleIndex(self.worker_config, src=self)
        self._batch_sample_index = SampleIndex(self.worker_config, src=self)
        self._buckets = {}

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Return an upper bound. This is for sure not correct.
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_batch]:
        buckets = self._buckets
        if buckets is None:
            buckets = self._buckets = dict()

        # Load saved state if available
        for bucket in buckets.values():
            bucket.samples.worker_start()

        # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="")
        # for bucket_key, bucket in buckets.items():
        #     print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="")
        #     bucket.samples.debug_print(" ")
        # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="")

        def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]:
            # Debug print the state
            # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="")
            # for dbg_bucket_key, dbg_bucket in buckets.items():
            #     print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="")
            #     dbg_bucket.samples.debug_print(" ")
            batch_items, sample_restore_keys = bucket.samples.flush()
            # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="")
            try:
                with self._batch_sample_index.ctx() as sample_idx:
                    batch_sample = self.batcher(batch_items)
                    assert not isinstance(batch_sample, Generator), (
                        f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet."
                    )
                self._last_batch_failures = 0
                set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
                yield batch_sample
            except SkipSample:
                pass
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(batch_items)
            except Exception as e:
                self.error_handler(e, batch_items)
                self._last_batch_failures += 1
                if (
                    self.failure_tolerance > 0
                    and self._last_batch_failures >= self.failure_tolerance
                ):
                    raise FatalSampleError.from_sample(
                        batch_items,
                        f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
                    )

        # Add samples to the buckets
        for sample in self.dataset:
            try:
                with self._group_key_sample_index.ctx():
                    bucket_key, batch_size = self.sample_group_key(sample)
                assert (batch_size is None) != (self.fixed_batch_size is None), (
                    f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed "
                    f"batch size is set to {self.fixed_batch_size}. One of the two should be None."
                )
                if self.fixed_batch_size is not None:
                    batch_size = self.fixed_batch_size
            except SkipSample:
                continue
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(sample)
            except Exception as e:
                self.error_handler(e, [sample])
                continue
            bucket = buckets.get(bucket_key)
            if bucket is None:
                assert batch_size is not None
                buckets[bucket_key] = bucket = Bucket(
                    batch_size=batch_size,
                    samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
                )
            else:
                assert bucket.batch_size == batch_size, (
                    f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}."
                )
            bucket.samples.append(sample)
            if bucket.samples.len_worker() >= bucket.batch_size:
                yield from flush(bucket)

        # Flush out last samples
        if not self.drop_last:
            for bucket in buckets.values():
                if bucket.samples.len_worker() > 0:
                    yield from flush(bucket)

        # Clear the buckets
        self._buckets.clear()

    def save_state(self) -> FlexState:
        return FlexState(
            bucket_sample_index=self._group_key_sample_index.save_state(),
            batch_sample_index=self._batch_sample_index.save_state(),
            buckets={key: bucket.save_state() for key, bucket in self._buckets.items()},
            **super().save_state(),
        )

    def restore_state(self, state: FlexState) -> None:
        super().restore_state(state)
        self._group_key_sample_index.restore_state(state["bucket_sample_index"])
        self._batch_sample_index.restore_state(state["batch_sample_index"])
        for key, bucket_state in state["buckets"].items():
            self._buckets[key] = Bucket(
                batch_size=-1,
                samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
            )
            self._buckets[key].restore_state(bucket_state)

    def can_restore_sample(self) -> bool:
        return super().can_restore_sample() and self.batcher_stateless

    def assert_can_restore(self) -> None:
        assert self.batcher_stateless, (
            f"Batcher {self.batcher} must be stateless to restore samples"
        )
        super().assert_can_restore()

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch:
        self.assert_can_restore()
        id, sample_idx, *sample_restore_keys = index
        assert id == type(self).__name__
        batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
        try:
            with self._batch_sample_index.ctx(sample_idx):
                batch_sample = self.batcher(batch)
            set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
            self._last_batch_failures = 0
        except SkipSample:
            pass
        except SYSTEM_EXCEPTIONS:
            raise FatalSampleError.from_sample(batch)
        except Exception as e:
            self.error_handler(e, batch)
            self._last_batch_failures += 1
            if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance:
                raise FatalSampleError.from_sample(
                    batch,
                    f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
                )
        return batch_sample

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "bucket": self._function_config(self.sample_group_key),
            "batcher": self._function_config(self.batcher),
            **(
                {
                    "batcher_config": (
                        self.batcher_config()
                        if callable(self.batcher_config)
                        else self.batcher_config
                    )
                }
                if self.batcher_config
                else {}
            ),
            "batcher_stateless": self.batcher_stateless,
            "drop_last": self.drop_last,
            "error_handler": self._function_config(self.error_handler),
            "worker_config": self.worker_config.config(),
            "dataset": self.dataset.config(),
        }

    def __str__(self):
        return f"GroupBatchDataset(bucket={self.sample_group_key}, batcher={self.batcher}, drop_last={self.drop_last}, dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/iter_map_dataset.py 0 → 100644
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Iterator,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from torch.utils.data import IterableDataset

from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key

T_sample = TypeVar("T_sample")
T_sample_out = TypeVar("T_sample_out")


class IterMapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]):
    """This dataset wrapper applies a custom function to transform the stream of samples and yield
    a new stream of samples.
    If used in a savable dataset context, it is critical that `iter_map_fn` is either stateless,
    or that the state of the `iter_map_fn` is saved and restored externally.
    """

    iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]]
    len_map_fn: Callable[[int], int]
    error_handler: Callable[[Exception, Optional[T_sample], list[SourceInfo]], None]
    stateless_iter_fn: bool
    iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    _sample_index: SampleIndex

    _savable_fields = ("_sample_index",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]],
        *,
        len_map_fn: Callable[[int], int] = lambda x: x,
        error_handler: Callable[
            [Exception, Optional[T_sample], list[SourceInfo]], None
        ] = log_exception,
        stateless_iter_fn: bool = False,
        iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        worker_config: WorkerConfig,
    ):
        """Construct an IterMapDataset.

        For saving and restoring samples, the iter_map_fn must only yield 0 or 1 sample per
        iterated sample.

        Args:
            dataset: The input dataset to wrap
            iter_map_fn: The function to apply to the stream of samples. Returns a new stream of
                samples. If savability should be preserved, this function should be stateless.
            len_map_fn: The function to apply to the length of the dataset. Returns the new
                (approximate) length of the resulting stream of samples based on the original
                length.
            error_handler: Handler for errors. Defaults to logging and ignoring the exception.
            stateless_iter_fn: If true, assume the iter_map_fn is deterministic and stateless
                (it does not aggregate samples (thus key for random access can propagate to inner
                dataset), yielding zero or multiple samples per fetched sample is fine).
                Defaults to False.
            iter_map_fn_config: Configuration for the iter_map_fn function. If callable, it should return the
                configuration. Defaults to None.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.iter_map_fn = iter_map_fn
        self.len_map_fn = len_map_fn
        self.error_handler = error_handler
        self.stateless_iter_fn = stateless_iter_fn
        self.iter_map_fn_config = iter_map_fn_config
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._sample_index = SampleIndex(self.worker_config, src=self)

    def len_worker(self, worker_idx: int | None = None) -> int:
        return self.len_map_fn(self.dataset.len_worker(worker_idx))

    def __iter__(self) -> Iterator[T_sample_out]:
        last_sample_wrapper = _LastSampleWrapper(self.dataset)

        # The iter_map_fn is stateless. Thus we need to know which inner sample created the
        # outer sample, and the relative outer sample index, so we can restore it.
        # This is the sample index within the currently yielded sample
        iter_idx = 0
        sample_idx = 0
        sample_restore_keys = []

        def reset_idx_iter() -> Generator[T_sample, None, None]:
            # Resets the inner sample index
            nonlocal iter_idx, sample_restore_keys
            for entry in last_sample_wrapper:
                iter_idx = 0
                sample_restore_keys.append(get_sample_restore_key(entry))
                yield entry

        ds_iter = iter(reset_idx_iter())

        # While True will break when the inner dataset is exhausted, but may continue on exception
        while True:
            iter_idx = 0
            try:
                for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)):
                    yield set_sample_restore_key(
                        sample,
                        sample_idx,
                        iter_idx,
                        *sample_restore_keys,
                        src=self,
                    )
                    sample_restore_keys.clear()
                    iter_idx += 1
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(last_sample_wrapper.last_sample)
            except Exception as e:
                self.error_handler(e, last_sample_wrapper.last_sample)
            else:
                break

    def can_restore_sample(self) -> bool:
        return super().can_restore_sample() and self.stateless_iter_fn

    def assert_can_restore(self) -> None:
        assert self.stateless_iter_fn, (
            "IterMapDataset can only restore samples if iter_map_fn is stateless."
        )
        super().assert_can_restore()

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        self.assert_can_restore()
        id, sample_idx, iter_idx, *sample_restore_keys = restore_key
        assert id == type(self).__name__
        assert isinstance(iter_idx, int)
        to_be_mapped = (
            self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys
        )
        try:
            inner_iter = iter(self.iter_map_fn(to_be_mapped))
            # Skip inner yielded samples to get the correct sample
            for skip_idx in range(iter_idx):
                with self._sample_index.ctx(sample_idx - iter_idx + skip_idx):
                    next(inner_iter)
            # This is the sample to restore
            with self._sample_index.ctx(sample_idx):
                sample = next(inner_iter)
            return set_sample_restore_key(
                sample,
                sample_idx,
                iter_idx,
                *sample_restore_keys,
                src=self,
            )
        except StopIteration:
            raise RuntimeError(
                "Generator did not yield enough samples, but is marked stateless/deterministic."
            )
        except GeneratorExit:
            raise FatalSampleError.from_sample(
                to_be_mapped,
                f"IterMapDataset {self.iter_map_fn} generator exited while trying to restore a sample.",
            )
        except SYSTEM_EXCEPTIONS:
            raise FatalSampleError.from_sample(to_be_mapped)
        except Exception as e:
            self.error_handler(e, to_be_mapped)
        finally:
            # Properly close if it's a generator
            if hasattr(inner_iter, "close"):
                inner_iter.close()

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "iter_map_fn": self._function_config(self.iter_map_fn),
            **(
                {
                    "iter_map_fn_config": (
                        self.iter_map_fn_config()
                        if callable(self.iter_map_fn_config)
                        else self.iter_map_fn_config
                    )
                }
                if self.iter_map_fn_config
                else {}
            ),
            "len_map_fn": self._function_config(self.len_map_fn),
            "error_handler": self._function_config(self.error_handler),
        }

    def __str__(self):
        return f"IterMapDataset(iter_map_fn={self.iter_map_fn}, dataset={self.dataset})"


class _LastSampleWrapper:
    """
    Wraps the inner dataset and stores the last iterated sample.
    """

    last_sample: Optional[T_sample] = None
    dataset: IterableDataset[T_sample]

    def __init__(self, dataset: IterableDataset[T_sample]):
        self.dataset = dataset

    def __iter__(self) -> Iterator[T_sample]:
        for sample in self.dataset:
            self.last_sample = sample
            yield sample
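A minimal sketch of a stateless `iter_map_fn`, as the class docstring requires for save/restore: each consumed sample produces at most one output. It is not part of the commit; the `text` field on the sample and `inner_dataset`/`worker_config` are assumptions for illustration only.

# Hypothetical usage sketch (not part of this file).
def split_words(samples):
    # Stateless: yields at most one output per consumed sample, so restore keys stay valid.
    for sample in samples:
        if len(sample.text) < 4:
            continue                   # zero outputs for this input
        yield sample.text.split()      # exactly one output for this input

mapped = IterMapDataset(
    inner_dataset,                     # assumed: an existing SavableDataset
    split_words,
    len_map_fn=lambda n: n,            # the reported length stays an approximation
    stateless_iter_fn=True,            # enables can_restore_sample()
    worker_config=worker_config,       # assumed: an existing WorkerConfig
)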
Megatron-Energon/src/megatron/energon/wrappers/limit_dataset.py 0 → 100644
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class LimitDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Limits the length of the dataset."""

    length: int
    current_offset: int

    _savable_fields = ("current_offset",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        length: int,
        *,
        reset_after_epoch: bool = False,
        worker_config: WorkerConfig,
    ):
        """
        Limits the length of the dataset.

        Args:
            dataset: The dataset to limit
            length: The length to limit to
            reset_after_epoch: If true, reset the underlying dataset after one epoch.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.length = length
        self.reset_after_epoch = reset_after_epoch
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self.current_offset = 0

    def len_worker(self, worker_idx: int | None = None) -> int:
        if worker_idx is None:
            self.worker_config.assert_worker()
            worker_idx = self.worker_config.rank_worker_id()
        if self.worker_config.num_workers <= 1:
            return self.length
        else:
            local_limit = self.length // self.worker_config.num_workers
            if worker_idx < self.length % self.worker_config.num_workers:
                local_limit += 1
            return local_limit

    def len_rank(self) -> int:
        return min(self.length, self.dataset.len_rank())

    def __iter__(self) -> Iterator[T_sample]:
        worker_id = self.worker_config.rank_worker_id()
        # Compute the local limit for this worker, i.e. all workers' limits sum up to the total
        if self.worker_config.num_workers <= 1:
            local_limit = self.length
        else:
            local_limit = self.length // self.worker_config.num_workers
            if worker_id < self.length % self.worker_config.num_workers:
                local_limit += 1
        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "LimitDataset.start",
                    "r": self.worker_config.rank,
                    "w": worker_id,
                    "offset": self.current_offset,
                    "local_limit": local_limit,
                    "limit": self.length,
                }
            )
        offset_range = list(range(self.current_offset, local_limit))
        # Only iterate self.dataset if there are samples to iterate
        if len(offset_range) > 0:
            for sample, offset in zip(
                self.dataset,
                offset_range,
            ):
                self.current_offset = offset + 1
                yield sample
        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "LimitDataset.done",
                    "r": self.worker_config.rank,
                    "w": worker_id,
                    "offset": self.current_offset,
                    "local_limit": local_limit,
                    "limit": self.length,
                }
            )
        # Reset the inner dataset
        self.dataset.reset_state_deep()
        self.current_offset = 0
        if self.reset_after_epoch:
            self.dataset.reset_state_deep()

    def worker_has_samples(self) -> bool:
        return super().worker_has_samples() and self.length > 0

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "length": self.length,
            "reset_after_epoch": self.reset_after_epoch,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"LimitDataset(length={self.length}, dataset={self.dataset})"