date_features.py 2.29 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
__all__ = ['CountryHolidays', 'SpecialDates']

import pandas as pd

def _transform_dict_holidays(dict_holidays_dates):
    """Invert a mapping, grouping the original keys by their shared value.

    Within this module the input is the mapping returned by a holiday
    extractor (keys are compared against calendar dates, values are used as
    holiday names), so the result maps each holiday name to its dates.

    Args:
        dict_holidays_dates: Mapping whose values are hashable.

    Returns:
        A dict mapping each distinct value to the list of keys that carried
        it, in the order they were first encountered.
    """
    grouped = {}
    for day, holiday_name in dict_holidays_dates.items():
        # setdefault creates the bucket on first sight of a holiday name.
        grouped.setdefault(holiday_name, []).append(day)
    return grouped

def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):
    """Build a 0/1 indicator dataframe with one column per (category, holiday).

    Args:
        dates: Datetime index to flag; it is also used as the index of the
            returned dataframe.
        categories: Iterable of category identifiers (e.g. country codes).
        holiday_extractor: Callable invoked as
            ``holiday_extractor(cat, years=years)``; expected to return a
            mapping whose keys are calendar dates and whose values are
            holiday names.
        supported_categories: Container used to validate each category.

    Returns:
        A ``pd.DataFrame`` indexed by ``dates`` with columns named
        ``"{category}_{holiday}"`` holding 1 where the calendar date of the
        index entry is that holiday and 0 otherwise.

    Raises:
        ValueError: If a category is not in ``supported_categories``.
    """
    years = dates.year.unique().tolist()
    # Convert the index to calendar dates once instead of once per column.
    calendar_days = [ds.date() for ds in dates]
    total_holidays = {}
    for cat in categories:
        if cat not in supported_categories:
            raise ValueError(f"Holidays for {cat} not available, please remove it.")
        dict_holidays = _transform_dict_holidays(holiday_extractor(cat, years=years))
        for holiday_name, holiday_dates in dict_holidays.items():
            # A set gives O(1) membership tests; the list made this quadratic.
            lookup = set(holiday_dates)
            total_holidays[f"{cat}_{holiday_name}"] = [
                int(day in lookup) for day in calendar_days
            ]
    return pd.DataFrame(total_holidays, index=dates)

class CountryHolidays:
    """Given a list of countries, returns a dataframe with holidays for each country.

    Instances are callables: calling one with a ``pd.DatetimeIndex`` returns
    a dataframe of 0/1 holiday indicator columns indexed by those dates.
    """

    def __init__(self, countries: list[str]):
        """Store the country identifiers whose holidays will be extracted."""
        self.countries = countries

    def __call__(self, dates: pd.DatetimeIndex):
        """Return the holiday indicator dataframe for ``dates``.

        Args:
            dates: Datetime index to flag.

        Returns:
            The dataframe produced by ``_get_holidays_df`` for the configured
            countries.

        Raises:
            ImportError: If the optional ``holidays`` package is missing.
        """
        # Imported lazily so the optional dependency is only required when
        # this feature is actually used.
        try:
            from holidays.utils import country_holidays, list_supported_countries
        except ModuleNotFoundError as err:
            # Chain explicitly so the traceback shows the missing module as
            # the direct cause instead of an unrelated secondary failure.
            raise ImportError(
                "You have to install additional libraries to use holidays, "
                'please install them using `pip install "nixtla[date_extras]"`'
            ) from err
        return _get_holidays_df(
            dates, self.countries, country_holidays, list_supported_countries()
        )

    def __name__(self):
        """Return the feature's name.

        NOTE: this is a regular method, so it has to be called
        (``obj.__name__()``) to obtain the string.
        """
        return "CountryHolidays"

class SpecialDates:
    """Given a dictionary of categories and dates, returns a dataframe with the special dates.

    Instances are callables: calling one with a ``pd.DatetimeIndex`` returns
    a dataframe with one 0/1 indicator column per category.
    """

    def __init__(self, special_dates: dict[str, list[str]]):
        """Store the mapping of category name -> list of date strings."""
        self.special_dates = special_dates

    def __call__(self, dates: pd.DatetimeIndex) -> pd.DataFrame:
        """Flag which entries of ``dates`` fall on a special date.

        Matching is done on the calendar date, so any time-of-day component
        of ``dates`` or of the configured special dates is ignored.

        Args:
            dates: Datetime index to flag; also the index of the result.

        Returns:
            A dataframe indexed by ``dates`` with one column per category,
            holding 1 where the entry's calendar date is listed for that
            category and 0 otherwise.
        """
        # Convert the index to calendar dates once, not once per category.
        calendar_days = [ds.date() for ds in dates]
        total_special_dates = {}
        for key, val in self.special_dates.items():
            # A set gives O(1) membership tests instead of scanning a list.
            marked_days = {ds.date() for ds in pd.to_datetime(val)}
            total_special_dates[key] = [
                int(day in marked_days) for day in calendar_days
            ]
        return pd.DataFrame(total_special_dates, index=dates)

    def __name__(self):
        """Return the feature's name.

        NOTE: this is a regular method, so it has to be called
        (``obj.__name__()``) to obtain the string.
        """
        return "SpecialDates"